Compare commits
42 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e94e888df3 | |||
| be73ad20a5 | |||
| f02a7c2976 | |||
| 331a1f5b3f | |||
| 8e345c5c5b | |||
| 81a43e6d92 | |||
| ade6376fa0 | |||
| bb4dd682dd | |||
| 5e497243f7 | |||
| b3f3c7758c | |||
| 9e1b649827 | |||
| 5120b21cc3 | |||
| dd76dec4ef | |||
| 19cc2a5feb | |||
| 09df6ac464 | |||
| df8a550d39 | |||
| 79fc51f4b8 | |||
| 6f4921858b | |||
| 62750a2b75 | |||
| 8c4d1dc47d | |||
| 3fe62887d8 | |||
| bd03b22f64 | |||
| 6c6b78550e | |||
| 06e560d98a | |||
| df18f5e4f5 | |||
| ca4fdbea70 | |||
| eefa171318 | |||
| afa1772203 | |||
| 9b3772dfa6 | |||
| b84e9802d8 | |||
| e9627ce55b | |||
| ad6e1ec19c | |||
| 0642d46dd4 | |||
| 833f6990e0 | |||
| affd1b693d | |||
| 6f55278121 | |||
| 3c28697b9f | |||
| bdd641790a | |||
| cc19d4d22b | |||
| 47daa33c61 | |||
| 389e493055 | |||
| 9eb01fa0b0 |
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,3 +1,4 @@
|
||||
# PyCache files
|
||||
__pycache__/
|
||||
cutlass_library.egg-info/
|
||||
cutlass_library.egg-info/
|
||||
/build*
|
||||
|
||||
157
CHANGELOG.md
157
CHANGELOG.md
@ -1,9 +1,120 @@
|
||||
# NVIDIA CUTLASS Changelog
|
||||
|
||||
|
||||
## [3.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.9.0) (2025-04-24)
|
||||
|
||||
* Support for Blackwell SM120 kernels for GeForce GPUs in CUTLASS 3.x API:
|
||||
- Collective mainloops that target for:
|
||||
* [Blockscaled datatypes with support for dense GEMM](./include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp)
|
||||
* [Blockscaled datatypes with support for sparse GEMM](./include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp)
|
||||
- New [GEMM](./include/cutlass/gemm/dispatch_policy.hpp) and [epilogue](./include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
|
||||
- [Blackwell SM120 epilogue](./include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp) and [full set of EVT fusions](./include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp).
|
||||
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM120 architecture:
|
||||
- [Blockscaled GEMM with NVFP4 input datatype and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu).
|
||||
- [Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor with scale factor generation](./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu).
|
||||
- [Blockscaled GEMM with mixed input datatype (MXFP8 and MXFP6) and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu).
|
||||
- [Grouped GEMM with nvfp4 datatype](./examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu).
|
||||
- [Sparse Blockscaled GEMM with mxfp8 input datatype and BF16 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu).
|
||||
- [Sparse Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu).
|
||||
* Set of unit tests that demonstrate the usage of both [sparse](./test/unit/gemm/device/sm120_blockscaled_sparse_tensorop_gemm/) and [dense](./test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/) Blackwell SM120 blockscaled GEMM.
|
||||
* Support for Blackwell SM100 Sparse kernels:
|
||||
- Collective mainloop that target for
|
||||
* [SM100 Sparse GEMM](./include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp)
|
||||
* Set of example that demonstrate the usage of the 3.x API for targeting Blackwell SM100 Sparse GEMM:
|
||||
- [Sparse GEMM](./examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm.cu)
|
||||
- [Blockscaled Sparse GEMM with NVFP4 input data type](./examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu)
|
||||
- [Blockscaled Sparse GEMM with mixed input data type (MXFP8 and MXFP4)](./examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu)
|
||||
* Set of unit tests that demonstrate the usage of [sparse](./test/unit/gemm/device/sm100_sparse_tensorop_gemm) and [blockscaled sparse](./test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm) Blackwell SM100 GEMM.
|
||||
* A new Multi-head Latent Attention (MLA) for SM100 Blackwell architecture in CUTLASS [example](./examples/77_blackwell_fmha/) covers the flashMLA-like weight-absorbed decoding use-case.
|
||||
* A new FMHA Backward kernel for SM100 Blackwell architecture extends CUTLASS [example](./examples/77_blackwell_fmha/) to show how the five backward pass MMAs can be fused into a single kernel to achieve high performance.
|
||||
* A new [distributed GEMM example](./examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu) for SM100 Blackwell architecture.
|
||||
* Enhancement and new support of block-wise and group-wise GEMM for Hopper and Blackwell architectures:
|
||||
- Enhancement of [blockwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture.
|
||||
- Enhancement of [groupwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture.
|
||||
- Support for [grouped GEMM with blockwise and groupwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture.
|
||||
- Support for [grouped-wise GEMM](./tools/profiler/src/blockwise_gemm_operation_profiler.cu) in CUTLASS profiler.
|
||||
- Support for [blockwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture.
|
||||
- Support for [groupwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture.
|
||||
- Support for [grouped GEMM with blockwise](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu) and [groupwise scaling](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu) for Blackwell architecture.
|
||||
* Added support for enhanced kernel performance search (auto-tuning) in CUTLASS profiler:
|
||||
- Sorting performance results by GFLOPs/second: Users can now sort the final performance report based on GFLOPs/second, making it easier to identify the most efficient kernels.
|
||||
- Exhaustive search for best kernel performance in GFLOPs/second: The profiler now searches for the best-performing kernel across a range of problem sizes, swizzle sizes, rasterization orders, and dynamic cluster configurations to maximize performance.
|
||||
- Performance search under a fixed GEMM shape: Enables exhaustive tuning within a fixed GEMM shape, exploring various kernel parameters to find the best configuration.
|
||||
- More detailed introductions and examples to leverage this feature can be found in [profiler.md](./media/docs/cpp/profiler.md#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss).
|
||||
* Support `void` as the D element in sm100 kernel epilogues.
|
||||
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
|
||||
* Optimal code generation with CUDA toolkit versions 12.8U1.
|
||||
|
||||
## [3.8.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.8.0) (2025-01-25)
|
||||
|
||||
* Support for new CuTe building blocks specifically for Blackwell SM100 architecture:
|
||||
- [5th generation Blackwell Tensor Core instructions (TCGen05)](./include/cute/atom/mma_traits_sm100.hpp) via CuTe MMA atoms.
|
||||
- Extensions to [Tensor Memory Accelerator](./include/cute/atom/copy_traits_sm100_tma.hpp) via CuTe Copy atoms.
|
||||
- Exposure of Blackwell's new tensor memory (note: distinct from TMA) as [`tmem`](./include/cute/pointer.hpp) across CuTe as a first class data locale.
|
||||
- Exposure of [`tmem->rmem`, `rmem->tmem` and `smem->tmem data movement instructions`](./include/cute/atom/copy_traits_sm100.hpp) as copy atoms in CuTe.
|
||||
- [`make_tmem_copy()`](./include/cute/atom/copy_traits_sm100.hpp) utility method to ease creation of tiled copies for tmem copy atoms.
|
||||
- Support for [new variants of LDSM on Blackwell](./include/cute/atom/copy_traits_sm100.hpp) via CuTe Copy atoms.
|
||||
* Support for new CUTLASS building blocks specifically for Blackwell SM100 architecture:
|
||||
- Various narrow precision [FP4, FP6, and FP8](./include/cutlass/exmy_base.h) formats as well as their [block-scaled variants NVFP4, MXFP4, MXFP6, and MXFP8](./include/cutlass/float_subbyte.h)
|
||||
- [Pipelines that implement Blackwell specific synchronization](./include/cutlass/pipeline/sm100_pipeline.hpp).
|
||||
- [Cluster launch control API supporting preferred and fallback cluster shapes](./include/cutlass/cluster_launch.hpp).
|
||||
- Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types.
|
||||
- Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./media/docs/cpp/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](./include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](./include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp).
|
||||
- Extensions to testbeds and reference check code for unit tests and CUTLASS profiler.
|
||||
* Full support for Blackwell SM100 kernels in CUTLASS 3.x API:
|
||||
- [Blackwell specific kernel layers](./include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that
|
||||
+ Implement a new warp-specialization recipe tuned specifically for Blackwell SM100 architecture.
|
||||
+ Leverage all the new features such as CLC based tile scheduling, preferred cluster, and TMEM based double buffering of accumulators.
|
||||
+ Support stream-K load balancing for all kernel types everywhere via composable scheduler support.
|
||||
- Blackwell collective mainloops that target the TCGen05 MMA instructions (both SS and TS) for
|
||||
* [Non-block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp)
|
||||
* [Non-block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp)
|
||||
* [Block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp)
|
||||
* [Block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp)
|
||||
- Blackwell [collective mainloop for convolution kernels](./include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp) supporting non-block scaled data types for fprop, dgrad, and wgrad.
|
||||
- New [GEMM](./include/cutlass/gemm/dispatch_policy.hpp), [convolution](./include/cutlass/conv/dispatch_policy.hpp), and [epilogue](./include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
|
||||
- [Blackwell epilogue that supports loading accumulators from `tmem`](./include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp) and [full set of EVT fusions]().
|
||||
* CUTLASS library and profiler integration for block scaled data types for kernel emission, profiling, and verification.
|
||||
- Support for preferred and fallback cluster shapes via profiler command line arguments parsing to set dynamic cluster shapes.
|
||||
- Support for dynamic datatypes by parsing profiler via profiler command line arguments parsing to set dynamic datatype setting in TCGen05 MMA instruction descriptors.
|
||||
- Support for mixed input GEMM kernels on Hopper in the profiler.
|
||||
* New CUTLASS profiler flag `use-cuda-graphs` to reduce overheads when benchmarking launch-bound kernels.
|
||||
* A new 3.x version of grouped GEMM to the CUTLASS library and generates kernels for Hopper and Blackwell. Now grouped GEMM support is enabled in the CUTLASS profiler (`./cutlass_profiler --operation=GroupedGemm --help` for details).
|
||||
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM100 architecture:
|
||||
- [Basic FP16 and FP8 GEMMs with minimal changes from Hopper examples](./examples/70_blackwell_gemm/), demonstrating ease of migration for off the shelf kernels using the 3.x collective builder API.
|
||||
- GEMM with [opt-in collective builder schedules showcasing available recipes](./examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu) for Blackwell.
|
||||
- Block scaled data type GEMMs targeting Blackwell's native block scaled Tensor Cores:
|
||||
+ [NVFP4 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu)
|
||||
+ [NVFP4 inputs with NVFP4 output](./examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu)
|
||||
+ [Mixed MXFP8 and MXFP6 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu)
|
||||
- GEMM example demonstrating [Blackwell's new preferred cluster support via dynamic cluster shapes](./examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu) for increased occupancy.
|
||||
- [GEMM with CLC based StreamK scheduler for load balancing](./examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu).
|
||||
- Grouped GEMM for [vanilla FP8 data inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu) and [NVFP4 block scaled inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu).
|
||||
- Convolution kernels for [fprop](./examples/76_blackwell_conv/76_blackwell_conv_fprop.cu), [dgrad](./examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu), and [wgrad](./examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu).
|
||||
- [Fused multi-head attention fprop kernel](./examples/77_blackwell_fmha/77_blackwell_fmha.cu) supporting fp16/bf16/fp8 data types across head dims of 32,64, and 128.
|
||||
- A new BF16x9 GEMM [kernel](./examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu) that emulates FP32 GEMM (SGEMM) using BF16 operations.
|
||||
* Set of examples that demonstrate the usage of the 3.x API for targeting Hopper architecture:
|
||||
- A set of new [Hopper grouped GEMM kernels](./examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes.
|
||||
- A new [Hopper FP8 GEMM with groupwise scaling](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu).
|
||||
* Documentation updates:
|
||||
- [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/cpp/quickstart.md#instantiating-a-blackwell-gemm-kernel).
|
||||
- Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/cpp/blackwell_functionality.md)
|
||||
- A new [functionality documentation](./media/docs/cpp/functionality.md) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures.
|
||||
- Updates to [compatibility](./README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](./README.md#Target-Architecture).
|
||||
- Updates to [profiler documentation](./media/docs/cpp/profiler.md) for testing mixed input GEMM kernels on Hopper.
|
||||
|
||||
## [3.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.7.0) (2025-01-11)
|
||||
- [Hopper blockwise scaling FP8 GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) uses 2D scaling tensor, assigning one value per threadblock. This allows a finer-grained scaling to be applied for each output tile per gemm-k iteration. The operands and scaling tensors are loaded from global memory to shared memory using TMA and cp_async, respectively. The scaling is applied inside the mainloop. Details with figures are [here](https://github.com/NVIDIA/cutlass/pull/1932#issue-2645398439).
|
||||
- [Distributed GEMM](./examples/65_distributed_gemm/65_distributed_gemm.cu) is a new (experimental) API which can turn existing CUTLASS GEMM kernels into pipelined Tensor Parallel GEMMs that run efficiently on NVLink-based network of GPUs. Its pipelining schedules can hide most of the communication behind computation, and relies on point-to-point communication, which can simply use CUDA runtime's peer device access feature. It also utilizes remote TMA loads and memcopies with CUDA graphs to handle communication primarily through the Copy Engine, leaving all SMs free for Hopper's persistent kernels. For more details you can refer to the [DistGEMM blog post](https://blog.shi-labs.com/distributed-gemm-88be6a481e2b).
|
||||
- Improved persistent grid launch for Hopper kernels with large cluster sizes (>= size of 4) using the new `make_kernel_hardware_info` API as shown in [example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
|
||||
- Enabled high precision accumulation for Hopper FP8 Sparse GEMM.
|
||||
- Potential API breaking changes:
|
||||
+ Fix `cute::UniversalCopy` for type safety.
|
||||
+ No longer implicitly select `cute::SM80_CP_ASYNC_*` based on input tensors. This avoids implicit downstream synchronization requirements. To use `SM80_CP_ASYNC`, users must explicitly select the appropriate CopyAtom.
|
||||
+ Fix `cute::SM80_CP_ASYNC_CACHEALWAYS`, `cute::SM80_CP_ASYNC_CACHEGLOBAL`, `cute::SM80_CP_ASYNC_CACHEALWAYS_ZFILL`, `cute::SM80_CP_ASYNC_CACHEGLOBAL_ZFILL` to avoid implicitly selecting `ZFILL` behavior on predication.
|
||||
+ Remove `cute::copy_vec<T>` in favor of `cute::copy_aligned` and `cute::copy(AutoVectorizingCopyWithAssumedAlignment<NumBits>,...)`.
|
||||
+ A refactor of default epilogue struct `DefaultEpilogue` [API](./include/cutlass/epilogue/collective/default_epilogue.hpp) to avoid reading non-void `ElementC` value for `ElementC = void` kernel.
|
||||
- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](./media/docs/cpp/profiler.md#cutlass-profiler).
|
||||
- Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
|
||||
- Optimal code generation with CUDA toolkit versions 12.6.
|
||||
|
||||
@ -15,19 +126,14 @@
|
||||
+ [INT8](./test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu)
|
||||
+ [TF32](./test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu)
|
||||
- A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](./include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered as a beta API.
|
||||
- Improve [mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md).
|
||||
+ Added a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode.
|
||||
+ Added [layout pre-shuffling](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L50-55) to optimize memory loading.
|
||||
+ Added [interleaved conversion](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu#L50-52) for `{INT4, UINT4, INT8}` x `{FP16, BF16}`.
|
||||
+ Other general optimizations.
|
||||
- The suffixes of the mixed input kernel schedules have been removed. Use `KernelTmaWarpSpecialized`, `KernelTmaWarpSpecializedPingpong` and `KernelTmaWarpSpecializedCooperative` instead.
|
||||
- [An improved mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode.
|
||||
- [EVT nodes for Top-K selection and softmax](./include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](./examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu).
|
||||
- [Programmatic Dependent Launch](./include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](./media/docs/dependent_kernel_launch.md).
|
||||
- [A new debugging tool, synclog](./include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details.
|
||||
- [Programmatic Dependent Launch](./include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](./media/docs/cpp/dependent_kernel_launch.md).
|
||||
- [A new debugging tool, synclog](./include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/cpp/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details.
|
||||
- A new TMA-enabled [epilogue](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings significant performance improvement, as well as its EVT support.
|
||||
- A SIMT-enabled pointer-array [epilogue](./include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp).
|
||||
- A new [Ping-Pong kernel schedule for Grouped GEMM](./include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations.
|
||||
- [A new instantiation strategy for CUTLASS profiler kernels](./python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/profiler.md#instantiating-more-kernels-with-hopper).
|
||||
- [A new instantiation strategy for CUTLASS profiler kernels](./python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/cpp/profiler.md#instantiating-more-kernels-with-hopper).
|
||||
- A new hardware support for comparisons and computations of [`cutlass::bfloat16_t`](./include/cutlass/bfloat16.h)
|
||||
- Fixed use of isnan on Windows for [`half_t`](./test/unit/core/functional.cu).
|
||||
- Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
|
||||
@ -37,7 +143,7 @@
|
||||
|
||||
- [Minimal SM90 WGMMA + TMA GEMM example in 100 lines of code](./examples/cute/tutorial/wgmma_sm90.cu)
|
||||
- [Exposure of L2 `cache_hint`s in TMA copy atoms](./include/cute/arch/copy_sm90_tma.hpp#L48)
|
||||
- Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/profiler.md#GEMM), and
|
||||
- Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/cpp/profiler.md#GEMM), and
|
||||
[example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
|
||||
- [TMA store based and EVT supported epilogues](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for [Hopper pointer array batched kernels](./test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu).
|
||||
- A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](./include/cutlass/gemm/device/gemm_sparse_universal.h) to enable serial and parallel split-k for sparse tensor cores and new tiny tile sizes to better support LLM inferrence:
|
||||
@ -50,7 +156,7 @@
|
||||
- Support for residual add (beta != 0) in convolution kernels.
|
||||
- A new convolution [epilogue](./examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output.
|
||||
- A refactor of [include files throughout CUTLASS core directories](./include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](./test/self_contained_includes/CMakeLists.txt).
|
||||
- [A guide for setting up VSCode to work well with CUTLASS](./media/docs/ide_setup.md) and [expanded code style guide](./media/docs/programming_guidelines.md).
|
||||
- [A guide for setting up VSCode to work well with CUTLASS](./media/docs/cpp/ide_setup.md) and [expanded code style guide](./media/docs/cpp/programming_guidelines.md).
|
||||
- Better support for MSVC as a host compiler.
|
||||
- Many performance optimizations, improvements, and bug fixes including fixes for FlashAttention-2.
|
||||
- Optimal code generation with CUDA toolkit versions 12.4 and 12.5u1.
|
||||
@ -58,7 +164,7 @@
|
||||
## [3.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.0) (2024-04-09)
|
||||
|
||||
- Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](./include/cute/atom/copy_traits_sm90_im2col.hpp)
|
||||
+ Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/gemm_api_3x.md).
|
||||
+ Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/cpp/gemm_api_3x.md).
|
||||
+ Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](./include/cutlass/conv/convnd_problem_shape.hpp).
|
||||
+ Support for [Fprop](./test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](./test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](./test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms
|
||||
+ [CUTLASS profiler support](./python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API.
|
||||
@ -70,7 +176,7 @@
|
||||
- 32x and 16x tile sizes are added to CUTLASS 2.x to improve the performance of narrow-tall and wide-short matrices.
|
||||
+ [Ampere FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm80.cu) and [NT](./test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu#L227-L301), [Ampere INT8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu#L392-L1342), [Ampere INT4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu#L372-L934).
|
||||
+ [Turing FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm75.cu#L55-L394), [Turing INT8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu#L166-L537), [Turing INT4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu#L310-L564).
|
||||
- Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cute/03_tensor.md), [MMA atoms](./media/docs/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](./examples/cute/tutorial).
|
||||
- Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cpp/cute/03_tensor.md), [MMA atoms](./media/docs/cpp/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](./examples/cute/tutorial).
|
||||
- Extensions to CuTe to support [L2 prefetching](./include/cute/algorithm/prefetch.hpp) and [TMA store+reductions](./include/cute/arch/copy_sm90_tma.hpp#L1337).
|
||||
- Remove C++11 requirement on a few CUTLASS 2.x API header files. All CUTLASS files now require C++17.
|
||||
- Fixes to greatly reduce build warnings.
|
||||
@ -89,7 +195,7 @@
|
||||
* Beta release of [Group-GEMM](./examples/57_hopper_grouped_gemm) utilizing TMA and WGMMA (requires CUDA 12.3 or above).
|
||||
* [Ampere Sparse GEMM](./examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu) supports Epilogue Visitor Tree (EVT) now.
|
||||
* NamedBarriers usability improvement and list of [ReservedNamedBarriers](./include/cutlass/arch/barrier.h) has been officially released.
|
||||
* Improved [CuTe documentation](./media/docs/cute/) including improved clarity and depth of [Quickstart](./media/docs/cute/00_quickstart.md), [CuTe Layout](./media/docs/cute/01_layout.md), and [CuTe Layout Algebra](./media/docs/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](./test/unit/cute/core/) also improved.
|
||||
* Improved [CuTe documentation](./media/docs/cpp/cute/) including improved clarity and depth of [Quickstart](./media/docs/cute/00_quickstart.md), [CuTe Layout](./media/docs/cpp/cute/01_layout.md), and [CuTe Layout Algebra](./media/docs/cpp/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](./test/unit/cute/core/) also improved.
|
||||
|
||||
## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3.0) (2023-10-31)
|
||||
* [Mixed-input Hopper GEMMs](./examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types.
|
||||
@ -140,7 +246,7 @@
|
||||
* Epilogue builders. Similar to mainloop builders (see [example 49](./examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu)), epilogue builders aim to generate the best-possible epilogue while exposing incremental opt-ins for greater customization.
|
||||
* Profiler support for overriding kernel and epilogue builder auto schedules for 3.x API kernels, allowing specific policies to be run in the CUTLASS profiler.
|
||||
* Performance optimizations for the [*warp-specialized persistent ping-pong*](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) kernel.
|
||||
* Changes to the [GEMM API 3.x](./media/docs/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs.
|
||||
* Changes to the [GEMM API 3.x](./media/docs/cpp/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs.
|
||||
* [FMHA Backward Pass](./examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu) from Meta xFormers.
|
||||
* [Streamk GEMM with Broadcast](./examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu) enables epilogue broadcast with StreamK GEMM.
|
||||
* [Batched B2B GEMM](./examples/13_two_tensor_op_fusion) now can run multiple Back-to-Back GEMM with the same problem size in parallel.
|
||||
@ -152,10 +258,10 @@
|
||||
* Updates and bugfixes from the community (thanks!)
|
||||
|
||||
## [3.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.0.0) (2023-01-23)
|
||||
* [CuTe](./media/docs/cute/00_quickstart.md), a [new core library and backend](./include/cute) for CUTLASS 3.0 that defines a single Layout vocabulary type and an associated algebra of layouts for a much more expressive and composable abstraction for tensors, sets of parallel agents, and operations by said agents on tensors.
|
||||
* [A new conceptual operation hierarchy](./media/docs/cutlass_3x_design.md) that replaces the architecture-centric hierarchy of CUTLASS 2.x and [documentation for CUTLASS 3.0's GEMM API changes](./media/docs/gemm_api_3x.md).
|
||||
* Strict API backwards compatibility that exposes both 2.x and 3.x API kernels through the same [`device::GemmUniversalAdapter`](./include/cutlass/gemm/device/gemm_universal_adapter.h) and [`kernel::GemmUniversal`](./include/cutlass/gemm/kernel/gemm_universal.hpp) types, allowing users to include both APIs in the same translation units. More information can be found in the [3.x backwards compatibility section](./media/docs/cutlass_3x_backwards_compatibility.md).
|
||||
* Updates to [Functionality](./media/docs/functionality.md) which directs users on which kernels are supported via CUTLASS-2 and CUTLASS-3.
|
||||
* [CuTe](./media/docs/cpp/cute/00_quickstart.md), a [new core library and backend](./include/cute) for CUTLASS 3.0 that defines a single Layout vocabulary type and an associated algebra of layouts for a much more expressive and composable abstraction for tensors, sets of parallel agents, and operations by said agents on tensors.
|
||||
* [A new conceptual operation hierarchy](./media/docs/cpp/cutlass_3x_design.md) that replaces the architecture-centric hierarchy of CUTLASS 2.x and [documentation for CUTLASS 3.0's GEMM API changes](./media/docs/cpp/gemm_api_3x.md).
|
||||
* Strict API backwards compatibility that exposes both 2.x and 3.x API kernels through the same [`device::GemmUniversalAdapter`](./include/cutlass/gemm/device/gemm_universal_adapter.h) and [`kernel::GemmUniversal`](./include/cutlass/gemm/kernel/gemm_universal.hpp) types, allowing users to include both APIs in the same translation units. More information can be found in the [3.x backwards compatibility section](./media/docs/cpp/cutlass_3x_backwards_compatibility.md).
|
||||
* Updates to [Functionality](./media/docs/cpp/functionality.md) which directs users on which kernels are supported via CUTLASS-2 and CUTLASS-3.
|
||||
* Updates to [Compatibility](./README.md#compatibility) Section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures and [Target Architecture](./README.md#Target-Architecture).
|
||||
* New warp-specialized GEMM [kernel schedules](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) and [mainloops](./include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters.
|
||||
* Extensions to CUTLASS profiler to support threadblock cluster shapes in library and profiler tile configurations.
|
||||
@ -333,7 +439,7 @@
|
||||
* Global memory iterators supporting Fprop, Dgrad, and Wgrad
|
||||
* `MmaMultistage` for implicit GEMM convolution for NVIDIA Ampere architecture
|
||||
* `MmaPipeline` for implicit GEMM convolution for NVIDIA Volta and Turing architectures
|
||||
* [Documentation](./media/docs/implicit_gemm_convolution.md) describing Implicit GEMM Convolution algorithm and implementation
|
||||
* [Documentation](./media/docs/cpp/implicit_gemm_convolution.md) describing Implicit GEMM Convolution algorithm and implementation
|
||||
|
||||
## [2.3.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.3.0) (2020-09-23)
|
||||
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
|
||||
@ -347,7 +453,7 @@
|
||||
* NVIDIA Ampere GPU Architecture examples and documentation:
|
||||
* [Tensor Float 32](./examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) and
|
||||
* [Sparse Tensor Cores](./examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu)
|
||||
* Documentation added on CUTLASS [efficient row-major epilogue](./media/docs/gemm_api.md#efficient-epilogue)
|
||||
* Documentation added on CUTLASS [efficient row-major epilogue](./media/docs/cpp/gemm_api.md#efficient-epilogue)
|
||||
|
||||
## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08)
|
||||
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
|
||||
@ -367,7 +473,7 @@
|
||||
* Disabled F16C by default for compatibility - enable on cmake command line with `-DCUTLASS_ENABLE_F16C=ON`
|
||||
|
||||
## [2.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.1.0) (2020-04-06)
|
||||
* BLAS-style host-side API added to [CUTLASS Library](./media/docs/quickstart.md#cutlass-library)
|
||||
* BLAS-style host-side API added to [CUTLASS Library](./media/docs/cpp/quickstart.md#cutlass-library)
|
||||
* API to launch compiled kernel instances for GEMM and planar complex GEMM
|
||||
* Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores
|
||||
* Computes complex matrix products on matrices stored as disjoint real and imaginary parts
|
||||
@ -381,10 +487,10 @@
|
||||
* Encapsulated functionality embodying modern C++11 programming techniques
|
||||
* Optimized containers and data types for efficient, generic, portable device code
|
||||
* Updates to:
|
||||
* [Quick start guide](./media/docs/quickstart.md)
|
||||
* [Quick start guide](./media/docs/cpp/quickstart.md)
|
||||
* [Documentation](./README.md#documentation)
|
||||
* [Utilities](./media/docs/utilities.md)
|
||||
* [CUTLASS Profiler](./media/docs/profiler.md)
|
||||
* [Utilities](./media/docs/cpp/utilities.md)
|
||||
* [CUTLASS Profiler](./media/docs/cpp/profiler.md)
|
||||
* Native Turing Tensor Cores
|
||||
* Efficient GEMM kernels targeting Turing Tensor Cores
|
||||
* Mixed-precision floating point, 8-bit integer, 4-bit integer, and binarized operands
|
||||
@ -477,4 +583,3 @@ SPDX-License-Identifier: BSD-3-Clause
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
|
||||
177
CMakeLists.txt
177
CMakeLists.txt
@ -102,6 +102,8 @@ set(CMAKE_CUDA_STANDARD_REQUIRED ON)
|
||||
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --expt-relaxed-constexpr)
|
||||
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -ftemplate-backtrace-limit=0)
|
||||
|
||||
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
|
||||
set(CMAKE_INSTALL_PREFIX install CACHE PATH "Default installation location." FORCE)
|
||||
endif()
|
||||
@ -114,6 +116,13 @@ set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.")
|
||||
find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED)
|
||||
|
||||
################################################################################
|
||||
|
||||
|
||||
include(customConfigs.cmake)
|
||||
|
||||
################################################################################
|
||||
|
||||
|
||||
set(CUTLASS_ENABLE_HEADERS_ONLY OFF CACHE BOOL "Enable only the header library")
|
||||
|
||||
if(CUTLASS_ENABLE_HEADERS_ONLY)
|
||||
@ -143,14 +152,14 @@ set(CUTLASS_ENABLE_PERFORMANCE ${CUTLASS_ENABLE_PROFILER} CACHE BOOL "Enable CUT
|
||||
set(CUTLASS_ENABLE_TESTS ${CUTLASS_ENABLE_TESTS_INIT} CACHE BOOL "Enable CUTLASS Tests")
|
||||
set(CUTLASS_ENABLE_GTEST_UNIT_TESTS ${CUTLASS_ENABLE_TESTS} CACHE BOOL "Enable CUTLASS GTest-based Unit Tests")
|
||||
set(CUTLASS_USE_SYSTEM_GOOGLETEST OFF CACHE BOOL "Use system/external installation of GTest")
|
||||
set(CUTLASS_USE_PACKED_TUPLE ON CACHE BOOL "If ON, make cute::tuple be new standard-layout tuple type; if OFF, use the original cute::tuple implementation that is _not_ standard-layout.")
|
||||
if (CUTLASS_USE_PACKED_TUPLE)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_USE_PACKED_TUPLE=1)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DCUTLASS_USE_PACKED_TUPLE=1")
|
||||
message(STATUS "Make cute::tuple be the new standard-layout tuple type")
|
||||
elseif()
|
||||
message(STATUS "Use the original cute::tuple implementation that is _not_ standard-layout")
|
||||
|
||||
if (CUTLASS_ENABLE_TESTS AND CUTLASS_ENABLE_PROFILER)
|
||||
set(CUTLASS_ENABLE_PROFILER_UNIT_TESTS_INIT ON)
|
||||
else()
|
||||
set(CUTLASS_ENABLE_PROFILER_UNIT_TESTS_INIT OFF)
|
||||
endif()
|
||||
set(CUTLASS_ENABLE_PROFILER_UNIT_TESTS ${CUTLASS_ENABLE_PROFILER_UNIT_TESTS_INIT} CACHE BOOL "Enable CUTLASS Profiler-based Unit Tests")
|
||||
set(CUTLASS_ENABLE_SELF_CONTAINED_INCLUDES_CHECK ON CACHE BOOL "Enable CUTLASS check for self-contained header includes")
|
||||
|
||||
################################################################################
|
||||
|
||||
@ -164,6 +173,11 @@ endif()
|
||||
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 90a)
|
||||
endif()
|
||||
|
||||
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.8)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a 101 101a 120 120a)
|
||||
endif()
|
||||
|
||||
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
|
||||
set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.")
|
||||
|
||||
@ -370,7 +384,21 @@ endif()
|
||||
|
||||
if (CUTLASS_ENABLE_GDC_FOR_SM90)
|
||||
message(STATUS "Grid Dependency Control (GDC) is enabled for SM90 kernels (required for programmatic dependent launches).")
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_ENABLE_GDC_FOR_SM90=1)
|
||||
list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_GDC_FOR_SM90=1)
|
||||
endif()
|
||||
|
||||
if (NOT DEFINED CUTLASS_ENABLE_GDC_FOR_SM100_DEFAULT)
|
||||
set(CUTLASS_ENABLE_GDC_FOR_SM100_DEFAULT ON)
|
||||
endif()
|
||||
|
||||
set(CUTLASS_ENABLE_GDC_FOR_SM100
|
||||
${CUTLASS_ENABLE_GDC_FOR_SM100_DEFAULT}
|
||||
CACHE BOOL
|
||||
"Enables Grid Dependency Control (GDC) for SM100 kernels (required for PDL).")
|
||||
|
||||
if (CUTLASS_ENABLE_GDC_FOR_SM100)
|
||||
message(STATUS "Grid Dependency Control (GDC) is enabled for SM100 kernels (required for programmatic dependent launches).")
|
||||
list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_GDC_FOR_SM100=1)
|
||||
endif()
|
||||
|
||||
set(CUTLASS_ENABLE_SYNCLOG OFF CACHE BOOL "Enable synchronization event logging for race condition debugging. WARNING: This redefines __syncthreads() and __syncwarp() in all downstream code!")
|
||||
@ -383,9 +411,18 @@ endif()
|
||||
|
||||
|
||||
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Blackwell features
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
|
||||
|
||||
# Warnings-as-error exceptions and warning suppressions for Clang builds
|
||||
if (CUTLASS_CLANG_HOST_COMPILE)
|
||||
|
||||
|
||||
set(FLAGS_TO_ADD
|
||||
"-Wno-error=implicit-int-conversion"
|
||||
"-Wno-error=pass-failed"
|
||||
@ -393,20 +430,20 @@ if (CUTLASS_CLANG_HOST_COMPILE)
|
||||
"-Wno-sign-conversion"
|
||||
"-Wno-unused-parameter"
|
||||
)
|
||||
|
||||
|
||||
foreach(FLAG ${FLAGS_TO_ADD})
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${FLAG}")
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS "${FLAG}")
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS "${FLAG}")
|
||||
endforeach()
|
||||
|
||||
|
||||
endif()
|
||||
|
||||
if (NOT MSVC AND CUTLASS_NVCC_KEEP)
|
||||
# MSVC flow handles caching already, but for other generators we handle it here.
|
||||
set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files")
|
||||
file(MAKE_DIRECTORY ${CUTLASS_NVCC_KEEP_DIR})
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --keep -v) # --keep-dir may not work with nvcc for some directories.
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --keep -v -objtemp) # --keep-dir may not work with nvcc for some directories.
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -save-temps=${CUTLASS_NVCC_KEEP_DIR})
|
||||
endif()
|
||||
|
||||
@ -433,6 +470,13 @@ if(UNIX)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-fno-strict-aliasing)
|
||||
endif()
|
||||
|
||||
# Known ctk11.4 issue (fixed later)
|
||||
# Also see https://stackoverflow.com/questions/64523302/cuda-missing-return-statement-at-end-of-non-void-function-in-constexpr-if-fun
|
||||
if (CUDA_VERSION VERSION_LESS 11.5.0)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcudafe "--diag_suppress=implicit_return_from_non_void_function" )
|
||||
message("CUDA_VERSION check pass ${CUDA_VERSION}")
|
||||
endif()
|
||||
|
||||
# Don't leak lineinfo in release builds
|
||||
if (NOT CMAKE_BUILD_TYPE MATCHES "Release")
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -gmlt)
|
||||
@ -465,7 +509,7 @@ if (CUTLASS_CLANG_DEVICE_COMPILE)
|
||||
|
||||
link_libraries(nvidia::cudart)
|
||||
link_libraries(nvidia::cuda_driver)
|
||||
|
||||
|
||||
endif()
|
||||
|
||||
#Report CUDA build flags
|
||||
@ -540,7 +584,7 @@ function(cutlass_apply_cuda_gencode_flags TARGET)
|
||||
list(APPEND __CMAKE_CUDA_ARCHS ${ARCH}-real)
|
||||
endif()
|
||||
if(CUTLASS_NVCC_EMBED_PTX AND NOT CUTLASS_CLANG_DEVICE_COMPILE)
|
||||
# If we're using clang for device compilation, the ptx is inserted
|
||||
# If we're using clang for device compilation, the ptx is inserted
|
||||
# via another command line option and the `-virtual` flags will cause an error.
|
||||
list(APPEND __CMAKE_CUDA_ARCHS ${ARCH}-virtual)
|
||||
endif()
|
||||
@ -669,6 +713,7 @@ target_include_directories(
|
||||
CUTLASS
|
||||
SYSTEM INTERFACE
|
||||
$<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include>
|
||||
$<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include/cccl>
|
||||
)
|
||||
|
||||
install(
|
||||
@ -901,7 +946,7 @@ function(cutlass_add_executable_tests NAME TARGET)
|
||||
if (NOT __DO_NOT_LOWERCASE_TEST_NAME)
|
||||
string(TOLOWER "${TESTCASE_NAME}" TESTCASE_NAME)
|
||||
endif()
|
||||
|
||||
|
||||
# The following rigmarole is needed to deal with spaces and possible quotes in
|
||||
# command line arguments. The options are passed "by reference" as the actual
|
||||
# variable names holding the real options. We then expand these in a way that
|
||||
@ -958,6 +1003,100 @@ function(cutlass_add_executable_tests NAME TARGET)
|
||||
|
||||
endfunction()
|
||||
|
||||
|
||||
|
||||
function(cutlass_generate_profiler_tests NAME)
|
||||
|
||||
set(options)
|
||||
set(oneValueArgs)
|
||||
set(multiValueArgs DEPENDS DEPENDEES CUTLASS_PROFILER_EXTRA_OPTIONS)
|
||||
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
if (NOT CUTLASS_BUILD_FOR_PROFILER_REGRESSIONS AND NOT CUTLASS_BUILD_FOR_PROFILER_PERFORMANCE_REGRESSIONS)
|
||||
return()
|
||||
endif()
|
||||
|
||||
install(
|
||||
FILES ${CUTLASS_PROFILER_REGRESSION_LIST_FILE}
|
||||
DESTINATION ${CMAKE_INSTALL_INFODIR}/cutlass/
|
||||
RENAME profiler_regressions.csv
|
||||
)
|
||||
|
||||
# Generate cmake test targets for each entry in the testlist csv
|
||||
|
||||
if (NOT EXISTS "${CUTLASS_PROFILER_REGRESSION_LIST_FILE}")
|
||||
message(SEND_ERROR "Profiler unit tests list path is invalid: CUTLASS_PROFILER_REGRESSION_LIST_FILE = ${CUTLASS_PROFILER_REGRESSION_LIST_FILE}")
|
||||
else()
|
||||
message(STATUS "Using ${CUTLASS_PROFILER_REGRESSION_LIST_FILE} to generate profiler-based tests.")
|
||||
endif()
|
||||
|
||||
file(STRINGS ${CUTLASS_PROFILER_REGRESSION_LIST_FILE} TEST_LIST)
|
||||
foreach(TEST IN LISTS TEST_LIST)
|
||||
set(TEMP_TEST ${TEST})
|
||||
if ("${TEST}" MATCHES " *cutlass_profiler.*")
|
||||
|
||||
# Generate a flattened name for the test from the test command line.
|
||||
string(REPLACE "," ";" TEST_NAME_LIST ${TEMP_TEST})
|
||||
string(REGEX REPLACE "\\*" "_" TEST_NAME "${TEMP_TEST}")
|
||||
string(REGEX REPLACE "\\\"\\{\\\"\\\"input_params.*\\{.*\\}\\}\\\"" "" TEST_NAME "${TEST_NAME}")
|
||||
string(REGEX REPLACE "\\\"\\{\\\"\\\"input_params.*\\{.*\\}\\}\\\"" "" TEST "${TEST}")
|
||||
string(REGEX REPLACE "," ";" TEST "${TEST}")
|
||||
string(REGEX MATCHALL "[a-zA-Z0-9_=]+" TEST_NAME "${TEST_NAME}")
|
||||
list(FILTER TEST_NAME EXCLUDE REGEX "cutlass_profiler|mode=trace|providers=cutlass")
|
||||
list(JOIN TEST_NAME "_" TEST_NAME)
|
||||
string(REGEX REPLACE "_verification_required=(true|false)" "" TEST_NAME "${TEST_NAME}")
|
||||
string(REPLACE "_verification_providers=device" "" TEST_NAME "${TEST_NAME}")
|
||||
string(REPLACE "batch_count=" "batch" TEST_NAME "${TEST_NAME}")
|
||||
string(REPLACE "cluster_m=" "" TEST_NAME "${TEST_NAME}")
|
||||
string(REPLACE "_cluster_n=" "x" TEST_NAME "${TEST_NAME}")
|
||||
string(REGEX REPLACE "_cluster_k=[0-9]+" "" TEST_NAME "${TEST_NAME}")
|
||||
string(REPLACE "cluster_m_fallback=" "" TEST_NAME "${TEST_NAME}")
|
||||
string(REPLACE "_cluster_n_fallback=" "x" TEST_NAME "${TEST_NAME}")
|
||||
string(REGEX REPLACE "_cluster_k_fallback=[0-9]+" "" TEST_NAME "${TEST_NAME}")
|
||||
string(REPLACE "runtime_input_datatype_a=" "" TEST_NAME "${TEST_NAME}")
|
||||
string(REPLACE "runtime_input_datatype_b=" "" TEST_NAME "${TEST_NAME}")
|
||||
string(REPLACE "swizzle_size=" "" TEST_NAME "${TEST_NAME}")
|
||||
string(REGEX REPLACE "verification_enabled=(true|false)" "" TEST_NAME "${TEST_NAME}")
|
||||
string(REGEX REPLACE "warmup_iterations=[0-9]+" "" TEST_NAME "${TEST_NAME}")
|
||||
string(REGEX REPLACE "profiling_iterations=[0-9]+" "" TEST_NAME "${TEST_NAME}")
|
||||
string(REGEX REPLACE "sleep_duration=[0-9]+" "" TEST_NAME "${TEST_NAME}")
|
||||
string(REGEX REPLACE "profiling_enabled=(true|false)" "" TEST_NAME "${TEST_NAME}")
|
||||
string(REPLACE "=" "" TEST_NAME "${TEST_NAME}")
|
||||
string(REPLACE "_error_on_no_match" "" TEST_NAME "${TEST_NAME}")
|
||||
string(REPLACE "_error_if_nothing_is_profiled" "" TEST_NAME "${TEST_NAME}")
|
||||
string(REPLACE "kernels" "" TEST_NAME "${TEST_NAME}")
|
||||
string(REPLACE "operation" "" TEST_NAME "${TEST_NAME}")
|
||||
|
||||
if (NOT __DO_NOT_LOWERCASE_TEST_NAME)
|
||||
string(TOLOWER "${TEST_NAME}" TEST_NAME)
|
||||
endif()
|
||||
|
||||
# Munge the test command
|
||||
|
||||
string(REPLACE "cutlass_profiler" "" TEST "${TEST}")
|
||||
set(TEST "${TEST}" ${__CUTLASS_PROFILER_EXTRA_OPTIONS} "--junit-output=${TEST_NAME}")
|
||||
set(TEST_COMMAND_${TEST_NAME} "${TEST}")
|
||||
list(APPEND TEST_COMMAND_VARS ${TEST_NAME})
|
||||
endif()
|
||||
|
||||
endforeach()
|
||||
|
||||
cutlass_add_executable_tests(
|
||||
${NAME} cutlass_profiler
|
||||
DEPENDS ${__DEPENDS}
|
||||
DEPENDEES ${__DEPENDEES}
|
||||
TEST_COMMAND_OPTIONS ${TEST_COMMAND_VARS}
|
||||
TEST_COMMAND_OPTIONS_PREFIX TEST_COMMAND_
|
||||
DISABLE_EXECUTABLE_INSTALL_RULE
|
||||
# Uncomment the following line when alloc/dealloc tracking
|
||||
# is fixed for all configurations.
|
||||
# TEST_SETS_SUPPORTED tmem_alloc_tracking
|
||||
)
|
||||
|
||||
endfunction()
|
||||
|
||||
|
||||
|
||||
if (CUTLASS_ENABLE_TOOLS)
|
||||
add_subdirectory(tools)
|
||||
if (CUTLASS_ENABLE_PROFILER)
|
||||
@ -975,6 +1114,14 @@ if (CUTLASS_ENABLE_TESTS)
|
||||
if (CUTLASS_ENABLE_GTEST_UNIT_TESTS)
|
||||
add_dependencies(test_all test_unit)
|
||||
endif()
|
||||
if (CUTLASS_ENABLE_PROFILER_UNIT_TESTS AND CUTLASS_BUILD_FOR_PROFILER_REGRESSIONS)
|
||||
# Generate profiler based unit test
|
||||
cutlass_generate_profiler_tests(
|
||||
tup
|
||||
DEPENDEES test_unit
|
||||
)
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
if (CUTLASS_INSTALL_TESTS)
|
||||
|
||||
163
CONTRIBUTORS.md
163
CONTRIBUTORS.md
@ -2,51 +2,104 @@
|
||||
|
||||
[README](./README.md#documentation) > **Contributors**
|
||||
|
||||
# CUTLASS Developers and Contributors
|
||||
# CUTLASS Developers **
|
||||
|
||||
This is the official list of CUTLASS developers and contributors.
|
||||
|
||||
## DEVELOPERS
|
||||
Vijay Thakkar<br />
|
||||
Pradeep Ramani<br />
|
||||
Cris Cecka<br />
|
||||
Aniket Shivam<br />
|
||||
Jack Kosaian<br />
|
||||
Mark Hoemmen<br />
|
||||
Richard Cai<br />
|
||||
Honghao Lu<br />
|
||||
Ethan Yan<br />
|
||||
Haicheng Wu<br />
|
||||
Andrew Kerr<br />
|
||||
Dustyn Blasig<br />
|
||||
Fengqi Qiao<br />
|
||||
Duane Merrill<br />
|
||||
Yujia Zhai<br />
|
||||
Rawn Henry<br />
|
||||
Sergey Klevtsov<br />
|
||||
Shang Zhang<br />
|
||||
Piotr Majcher<br />
|
||||
Paul Springer<br />
|
||||
Markus Hohnerbach<br />
|
||||
Jin Wang<br />
|
||||
Dustyn Blasig<br />
|
||||
Albert Xu<br />
|
||||
Junkai Wu<br />
|
||||
Xiuxia Zhang<br />
|
||||
Haicheng Wu<br />
|
||||
Jack Yang<br />
|
||||
Pradeep Ramani<br />
|
||||
Aditya Atluri<br />
|
||||
Han Li<br />
|
||||
Nick Zhao<br />
|
||||
Ivan Yin<br />
|
||||
Yu-Jung Chen<br />
|
||||
Markus Hoehnerbach<br />
|
||||
Honghao Lu<br />
|
||||
Mihir Awatramani<br />
|
||||
Hao Sheng<br />
|
||||
Zekun Fan<br />
|
||||
Aniket Shivam<br />
|
||||
Siyu Liu<br />
|
||||
Richard Cai<br />
|
||||
Vikas Gupta<br />
|
||||
Ethan Yan<br />
|
||||
Vijay Thakkar<br />
|
||||
Cris Cecka<br />
|
||||
Lawrence Ryan<br />
|
||||
Qun Song<br />
|
||||
Daniel Ricketts<br />
|
||||
dePaul Miller<br />
|
||||
Yuhan Li<br />
|
||||
Saman Ashkiani<br />
|
||||
Jack Chen<br />
|
||||
Shang Zhang<br />
|
||||
Petrick Liu<br />
|
||||
Questa Wang<br />
|
||||
Pramod Shenoy<br />
|
||||
Jack Kosaian<br />
|
||||
Yujia Zhai<br />
|
||||
Zhaodong Chen<br />
|
||||
Manas Sahni<br />
|
||||
Shunfan Shao<br />
|
||||
Fengqi Qiao<br />
|
||||
Serif Yesil<br />
|
||||
Aragorn Guan<br />
|
||||
Heidi He<br />
|
||||
Xiao Song<br />
|
||||
Sergey Klevtsov<br />
|
||||
Jiang Shao<br />
|
||||
Ruqing Xu<br />
|
||||
Mengyu Guo<br />
|
||||
Tao Xie<br />
|
||||
Linfeng Zheng<br />
|
||||
Harrison Barclay<br />
|
||||
Wenfei Tang<br />
|
||||
Diksha Gohlyan<br />
|
||||
Alexander Zhurkevich<br />
|
||||
Siyuan Fu<br />
|
||||
Hua Huang<br />
|
||||
Xiufan Liang<br />
|
||||
Ian Tramble<br />
|
||||
Ali Hassani<br />
|
||||
Shreya Gaur<br />
|
||||
|
||||
** _The list is sorted in order of the author's first contribution to the CUTLASS project._
|
||||
|
||||
|
||||
# CUTE Developers
|
||||
|
||||
## CuTe
|
||||
Cris Cecka<br />
|
||||
Vijay Thakkar<br />
|
||||
|
||||
## CUTLASS Product Manager
|
||||
|
||||
# CUTLASS Product Manager
|
||||
|
||||
Matthew Nicely<br />
|
||||
|
||||
## Former CUTLASS Developers
|
||||
Manish Gupta<br />
|
||||
Naila Farooqui<br />
|
||||
David Tanner<br />
|
||||
Manikandan Ananth<br />
|
||||
Zhaodong Chen<br />
|
||||
Chinmay Talegaonkar<br />
|
||||
|
||||
## CONTRIBUTORS
|
||||
# Former CUTLASS Developers
|
||||
|
||||
Manish Gupta<br />
|
||||
Duane Merrill<br />
|
||||
Piotr Majcher<br />
|
||||
Naila Farooqui<br />
|
||||
Mark Hoemmen<br />
|
||||
Rawn Henry<br />
|
||||
Jin Wang<br />
|
||||
Timmy Liu<br />
|
||||
Manikandan Ananth<br />
|
||||
David Tanner<br />
|
||||
|
||||
|
||||
# Acknowledgements
|
||||
|
||||
Tri Dao<br />
|
||||
Jay Shah<br />
|
||||
Timothy Costa<br />
|
||||
Julien Demouth<br />
|
||||
Brian Fahs<br />
|
||||
@ -56,25 +109,15 @@ Mostafa Hagog<br />
|
||||
Fei Hu<br />
|
||||
Alan Kaatz<br />
|
||||
Tina Li<br />
|
||||
Timmy Liu<br />
|
||||
Wei Liu<br />
|
||||
Tim Martin<br />
|
||||
Duane Merrill<br />
|
||||
Kevin Siu<br />
|
||||
Markus Tavenrath<br />
|
||||
John Tran<br />
|
||||
Vicki Wang<br />
|
||||
Junkai Wu<br />
|
||||
Fung Xie<br />
|
||||
Albert Xu<br />
|
||||
Yang Xu<br />
|
||||
Jack Yang<br />
|
||||
Scott Yokim<br />
|
||||
Xiuxia Zhang<br />
|
||||
Nick Zhao<br />
|
||||
|
||||
## ACKNOWLEDGEMENTS
|
||||
|
||||
Girish Bharambe<br />
|
||||
Luke Durant<br />
|
||||
Carter Edwards<br />
|
||||
@ -85,3 +128,35 @@ Bryce Lelbach<br />
|
||||
Joel McCormack<br />
|
||||
Kyrylo Perelygin<br />
|
||||
Sean Treichler<br />
|
||||
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
@ -1,7 +1,15 @@
|
||||
# Publications Using Cutlass
|
||||
|
||||
## 2025
|
||||
|
||||
- ["Comet: Fine-grained Computation-communication Overlapping for Mixture-of-Experts"](https://arxiv.org/abs/2502.19811). Shulai Zhang, Ningxin Zheng, Haibin Lin, Ziheng Jiang, Wenlei Bao, Chengquan Jiang, Qi Hou, Weihao Cui, Size Zheng, Li-Wen Chang, Quan Chen, Xin Liu. _arXiv_, February 2025.
|
||||
|
||||
- ["ParetoQ: Scaling Laws in Extremely Low-bit LLM Quantization"](https://arxiv.org/abs/2502.02631). Zechun Liu, Changsheng Zhao, Hanxian Huang, Sijia Chen, Jing Zhang, Jiawei Zhao, Scott Roy, Lisa Jin, Yunyang Xiong, Yangyang Shi, Lin Xiao, Yuandong Tian, Bilge Soran, Raghuraman Krishnamoorthi, Tijmen Blankevoort, Vikas Chandra. _arXiv_, February 2025.
|
||||
|
||||
## 2024
|
||||
|
||||
- ["DeepSeek-V3 Technical Report"](https://arxiv.org/abs/2412.19437). DeepSeek-AI. _arXiv_, December 2024.
|
||||
|
||||
- ["ShadowKV: KV Cache in Shadows for High-Throughput Long-Context LLM Inference"](https://arxiv.org/abs/2410.21465). Hanshi Sun, Li-Wen Chang, Wenlei Bao, Size Zheng, Ningxin Zheng, Xin Liu, Harry Dong, Yuejie Chi, Beidi Chen. _arXiv_, October 2024.
|
||||
|
||||
- ["FLUX: Fast Software-based Communication Overlap On GPUs Through Kernel Fusion"](https://arxiv.org/abs/2406.06858). Li-Wen Chang, Wenlei Bao, Qi Hou, Chengquan Jiang, Ningxin Zheng, Yinmin Zhong, Xuanrun Zhang, Zuquan Song, Chengji Yao, Ziheng Jiang, Haibin Lin, Xin Jin, Xin Liu. _arXiv_, June 2024.
|
||||
@ -60,3 +68,35 @@
|
||||
"](https://arxiv.org/abs/2008.13006). Cong Guo, Bo Yang Hsueh, Jingwen Leng, Yuxian Qiu, Yue Guan, Zehuan Wang, Xiaoying Jia, Xipeng Li, Minyi Guo, Yuhao Zhu. _Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis_, November 2020.
|
||||
|
||||
- ["Strassen's Algorithm Reloaded on GPUs"](https://dl.acm.org/doi/10.1145/3372419). Jianyu Huang, Chenhan D. Yu, Robert A. van de Geijn. _ACM Transactions on Mathematical Software_, March 2020.
|
||||
|
||||
## Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
259
README.md
259
README.md
@ -1,8 +1,8 @@
|
||||

|
||||
|
||||
# CUTLASS 3.7.0
|
||||
# CUTLASS 3.9.0
|
||||
|
||||
_CUTLASS 3.7.0 - January 2025_
|
||||
_CUTLASS 3.9.0 - April 2025_
|
||||
|
||||
CUTLASS is a collection of CUDA C++ template abstractions for implementing
|
||||
high-performance matrix-matrix multiplication (GEMM) and related computations at all levels
|
||||
@ -16,63 +16,85 @@ as building blocks within custom kernels and applications.
|
||||
|
||||
To support a wide variety of applications, CUTLASS provides extensive support for
|
||||
mixed-precision computations, providing specialized data-movement and
|
||||
multiply-accumulate abstractions for half-precision floating
|
||||
point (FP16), BFloat16 (BF16), Tensor Float 32 (TF32),
|
||||
single-precision floating point (FP32),
|
||||
[FP32 emulation via tensor core instruction](./examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm),
|
||||
double-precision floating
|
||||
point (FP64) types, integer data types (4b and 8b), and binary data types (1b).
|
||||
CUTLASS demonstrates warp-synchronous matrix multiply operations
|
||||
multiply-accumulate abstractions for FP64, FP32, TF32, FP16, BF16,
|
||||
[FP32 emulation via tensor core instruction](./examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm),
|
||||
8b floating point types (e5m2 and e4m3),
|
||||
block scaled data types (NVIDIA NVFP4 and OCP standard MXFP4, MXFP6, MXFP8),
|
||||
narrow integer types (4 and 8b signed and unsigned integers),
|
||||
and binary 1b data types (where architectures allow for the
|
||||
native support of such data types).
|
||||
CUTLASS demonstrates optimal matrix multiply operations
|
||||
targeting the programmable, high-throughput _Tensor Cores_ implemented by
|
||||
NVIDIA's Volta, Turing, Ampere, and Hopper architectures.
|
||||
NVIDIA's Volta, Turing, Ampere, Ada, Hopper, and Blackwell architectures.
|
||||
|
||||
See the [Quick Start Guide](./media/docs/quickstart.md) to get started quickly.
|
||||
In addition to GEMMs, CUTLASS implements high-performance convolution via
|
||||
the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution
|
||||
operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline.
|
||||
This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components.
|
||||
|
||||
See the [functionality listing](./media/docs/functionality.md) for the list of operations
|
||||
supported at each level of the execution model hierarchy.
|
||||
See the [Quick Start Guide](./media/docs/cpp/quickstart.md) to get started quickly.
|
||||
|
||||
CUTLASS 3.0 introduced a new core library, CuTe, to describe and manipulate tensors of threads and data.
|
||||
CuTe is a collection of C++ CUDA template abstractions for defining and operating on hierarchically multidimensional layouts of threads and data. CuTe provides `Layout` and `Tensor` objects that compactly package the type, shape, memory space, and layout of data, while performing the complicated indexing for the user. This lets programmers focus on the logical descriptions of their algorithms while CuTe does the mechanical bookkeeping for them. With these tools, we can quickly design, implement, and modify all dense linear algebra operations.
|
||||
See the [functionality docs](./media/docs/cpp/functionality.md) for a more comprehensive
|
||||
list of kernel level features, data types, instructions, and minimum supported by CUTLASS on each GPU
|
||||
architecture.
|
||||
|
||||
The core abstractions of CuTe are hierarchically multidimensional layouts which can be composed with data arrays to represent tensors. The representation of layouts is powerful enough to represent nearly everything we need to implement efficient dense linear algebra. Layouts can also be combined and manipulated via functional composition, on which we build a large set of common operations such as tiling and partitioning.
|
||||
# What's New in CUTLASS 3.9
|
||||
|
||||
CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates. This greatly simplifies the design
|
||||
and improves code composability and readability. More documentation specific to CuTe can be found in its [dedicated documentation directory](./media/docs/cute/00_quickstart.md).
|
||||
* Support for Blackwell SM120 kernels for GeForce GPUs in CUTLASS 3.x API:
|
||||
- Collective mainloops that target for:
|
||||
* [Blockscaled datatypes with support for dense GEMM](./include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp)
|
||||
* [Blockscaled datatypes with support for sparse GEMM](./include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp)
|
||||
- New [GEMM](./include/cutlass/gemm/dispatch_policy.hpp) and [epilogue](./include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
|
||||
- [Blackwell SM120 epilogue](./include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp) and [full set of EVT fusions](./include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp).
|
||||
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM120 architecture:
|
||||
- [Blockscaled GEMM with NVFP4 input datatype and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu).
|
||||
- [Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor with scale factor generation](./examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu).
|
||||
- [Blockscaled GEMM with mixed input datatype (MXFP8 and MXFP6) and BF16 output tensor](./examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu).
|
||||
- [Grouped GEMM with nvfp4 datatype](./examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu).
|
||||
- [Sparse Blockscaled GEMM with mxfp8 input datatype and BF16 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu).
|
||||
- [Sparse Blockscaled GEMM with NVFP4 input datatype and NVFP4 output tensor](./examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu).
|
||||
* Set of unit tests that demonstrate the usage of both [sparse](./test/unit/gemm/device/sm120_blockscaled_sparse_tensorop_gemm/) and [dense](./test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/) Blackwell SM120 blockscaled GEMM.
|
||||
* Support for Blackwell SM100 Sparse kernels:
|
||||
- Collective mainloop that target for
|
||||
* [SM100 Sparse GEMM](./include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp)
|
||||
* Set of example that demonstrate the usage of the 3.x API for targeting Blackwell SM100 Sparse GEMM:
|
||||
- [Sparse GEMM](./examples/83_blackwell_sparse_gemm/83_blackwell_sparse_gemm.cu)
|
||||
- [Blockscaled Sparse GEMM with NVFP4 input data type](./examples/84_blackwell_narrow_precision_sparse_gemm/84a_blackwell_nvfp4_bf16_sparse_gemm.cu)
|
||||
- [Blockscaled Sparse GEMM with mixed input data type (MXFP8 and MXFP4)](./examples/84_blackwell_narrow_precision_sparse_gemm/84b_blackwell_mixed_mxfp8_bf16_sparse_gemm.cu)
|
||||
* Set of unit tests that demonstrate the usage of [sparse](./test/unit/gemm/device/sm100_sparse_tensorop_gemm) and [blockscaled sparse](./test/unit/gemm/device/sm100_blockscaled_sparse_tensorop_gemm) Blackwell SM100 GEMM.
|
||||
* A new Multi-head Latent Attention (MLA) for SM100 Blackwell architecture in CUTLASS [example](./examples/77_blackwell_fmha/) covers the flashMLA-like weight-absorbed decoding use-case.
|
||||
* A new FMHA Backward kernel for SM100 Blackwell architecture extends CUTLASS [example](./examples/77_blackwell_fmha/) to show how the five backward pass MMAs can be fused into a single kernel to achieve high performance.
|
||||
* A new [distributed GEMM example](./examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu) for SM100 Blackwell architecture.
|
||||
* Enhancement and new support of block-wise and group-wise GEMM for Hopper and Blackwell architectures:
|
||||
- Enhancement of [blockwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) for Hopper architecture.
|
||||
- Enhancement of [groupwise GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu) for Hopper architecture.
|
||||
- Support for [grouped GEMM with blockwise and groupwise scaling](./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/) for Hopper architecture.
|
||||
- Support for [grouped-wise GEMM](./tools/profiler/src/blockwise_gemm_operation_profiler.cu) in CUTLASS profiler.
|
||||
- Support for [blockwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu) for Blackwell architecture.
|
||||
- Support for [groupwise GEMM](./examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu) for Blackwell architecture.
|
||||
- Support for [grouped GEMM with blockwise](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_blockwise.cu) and [groupwise scaling](./examples/81_blackwell_gemm_blockwise/81_blackwell_grouped_gemm_groupwise.cu) for Blackwell architecture.
|
||||
* Added support for enhanced kernel performance search (auto-tuning) in CUTLASS profiler:
|
||||
- Sorting performance results by GFLOPs/second: Users can now sort the final performance report based on GFLOPs/second, making it easier to identify the most efficient kernels.
|
||||
- Exhaustive search for best kernel performance in GFLOPs/second: The profiler now searches for the best-performing kernel across a range of problem sizes, swizzle sizes, rasterization orders, and dynamic cluster configurations to maximize performance.
|
||||
- Performance search under a fixed GEMM shape: Enables exhaustive tuning within a fixed GEMM shape, exploring various kernel parameters to find the best configuration.
|
||||
- More detailed introductions and examples to leverage this feature can be found in [profiler.md](./media/docs/cpp/profiler.md#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss).
|
||||
* Support `void` as the D element in sm100 kernel epilogues.
|
||||
|
||||
In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components.
|
||||
Note: CUTLASS 3.x builds are known to be down on Windows platforms for all CUDA toolkits.
|
||||
CUTLASS team is working on a fix.
|
||||
|
||||
# What's New in CUTLASS 3.7
|
||||
|
||||
CUTLASS 3.7.0 is an update to CUTLASS adding:
|
||||
|
||||
- A new [Hopper blockwise scaling FP8 GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) where the operands and block scaling tensor are staged via shared memory.
|
||||
- [Distributed GEMM](./examples/65_distributed_gemm/65_distributed_gemm.cu) is an experimental pipelined Tensor Parallelism implementation utilizing existing CUTLASS kernels and CUDA runtime features, which can hide the most of communication behind computation.
|
||||
- Improved persistent grid launch for Hopper kernels with large cluster sizes (>= size of 4) using the new `make_kernel_hardware_info` API as shown in [example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
|
||||
- Enabled high precision accumulation for Hopper FP8 Sparse GEMM.
|
||||
|
||||
Minimum requirements:
|
||||
|
||||
- Architecture: Volta
|
||||
- Compiler: Must support at least C++17
|
||||
- CUDA Toolkit version: 11.4
|
||||
|
||||
Starting from CUTLASS 3.0, CUTLASS removed support for the following:
|
||||
|
||||
- Maxwell and Pascal GPU architectures
|
||||
- Ubuntu 16.04
|
||||
- CUDA 10.2
|
||||
- C++ language versions less than 17.
|
||||
|
||||
**See the [CHANGELOG](CHANGELOG.md) for a detailed listing of releases and updates.**
|
||||
**See the [CHANGELOG](CHANGELOG.md) for details of all past releases and updates.**
|
||||
|
||||
# Performance
|
||||
|
||||
<p align="center"><img src=media/images/cutlass-3.5.1-gemm-peak-performance.png></p>
|
||||
<p align="center"><img src=media/images/cutlass-3.5.1-gemm-peak-performance-fp8.png></p>
|
||||
|
||||
CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
|
||||
they exhibit peak performance comparable to cuBLAS for scalar GEMM
|
||||
computations. The above figure shows the continual CUTLASS performance improvements
|
||||
they exhibit nearly optimal utilization of peak theoretical throughput. The figure below
|
||||
shows CUTLASS 3.8's performance as a % of theoretical peak utilization
|
||||
on various input and output data types when run on NVIDIA Blackwell SM100 architecture GPU.
|
||||
|
||||
<p align="center"><img src=media/images/cutlass-3.8-blackwell-gemm-peak-performance.svg></p>
|
||||
|
||||
The two figures below show the continual CUTLASS performance improvements
|
||||
on an [NVIDIA H100](https://www.nvidia.com/en-us/data-center/h100/) (NVIDIA Hopper architecture) since
|
||||
CUTLASS 3.1.
|
||||
CUTLASS 3.5.1 was compiled with the [CUDA 12.5u1 Toolkit](https://developer.nvidia.com/cuda-downloads).
|
||||
@ -80,20 +102,45 @@ Tensor Core operations are implemented using CUDA's
|
||||
[mma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma) and
|
||||
[wgmma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) instructions.
|
||||
|
||||
<p align="center"><img src=media/images/cutlass-2.9-implicit-gemm-performance.png></p>
|
||||
<p align="center"><img src=media/images/cutlass-3.5.1-gemm-peak-performance.png></p>
|
||||
<p align="center"><img src=media/images/cutlass-3.5.1-gemm-peak-performance-fp8.png></p>
|
||||
|
||||
When using CUTLASS building blocks to construct device-wide implicit gemm (Fprop, Dgrad, and Wgrad)
|
||||
kernels, CUTLASS performance is also comparable to cuDNN when running Resnet-50 layers on an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/)
|
||||
as shown in the above figure. Tensor Core operations are implemented using CUDA's
|
||||
[mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma).
|
||||
# CuTe
|
||||
|
||||
CUTLASS 3.0 introduced a new core library, CuTe, to describe and manipulate tensors of threads and data.
|
||||
CuTe is a collection of C++ CUDA template abstractions for
|
||||
defining and operating on hierarchically multidimensional layouts of threads and data.
|
||||
CuTe provides `Layout` and `Tensor` objects that compactly package the type,
|
||||
shape, memory space, and layout of data, while performing the complicated indexing for the user.
|
||||
This lets programmers focus on the logical descriptions of their algorithms while
|
||||
CuTe does the mechanical bookkeeping for them. With these tools, we can quickly design,
|
||||
implement, and modify all dense linear algebra operations.
|
||||
|
||||
The core abstractions of CuTe are hierarchically multidimensional layouts
|
||||
which can be composed with data arrays to represent tensors.
|
||||
The representation of layouts is powerful enough to represent nearly
|
||||
everything we need to implement efficient dense linear algebra.
|
||||
Layouts can also be combined and manipulated via functional composition, on which we build a large set of common operations such as tiling and partitioning.
|
||||
|
||||
CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates.
|
||||
This greatly simplifies the design and improves code composability and readability.
|
||||
More documentation specific to CuTe can be found in its
|
||||
[dedicated documentation directory](./media/docs/cpp/cute/00_quickstart.md).
|
||||
|
||||
# Compatibility
|
||||
|
||||
Minimum requirements:
|
||||
|
||||
- Architecture: Volta (compute capability 7.0)
|
||||
- Compiler: Must support at least C++17
|
||||
- CUDA Toolkit version: 11.4
|
||||
|
||||
CUTLASS requires a C++17 host compiler and
|
||||
performs best when built with the [**CUDA 12.4 Toolkit**](https://developer.nvidia.com/cuda-downloads).
|
||||
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0, CUDA 12.1, CUDA 12.2.2, CUDA 12.3.1 and CUDA 12.3.2.
|
||||
performs best when built with the [**CUDA 12.8 Toolkit**](https://developer.nvidia.com/cuda-downloads).
|
||||
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, and all other CUDA 12.x versions.
|
||||
|
||||
## Operating Systems
|
||||
|
||||
We have tested the following environments.
|
||||
|
||||
|**Operating System** | **Compiler** |
|
||||
@ -101,73 +148,101 @@ We have tested the following environments.
|
||||
| Ubuntu 18.04 | GCC 7.5.0 |
|
||||
| Ubuntu 20.04 | GCC 10.3.0 |
|
||||
| Ubuntu 22.04 | GCC 11.2.0 |
|
||||
| Ubuntu 22.04 | Clang 10.0.0 |
|
||||
| Ubuntu 22.04 | Clang 14.0.6 |
|
||||
| Ubuntu 22.04 | Clang 17.0.6 |
|
||||
| Windows 10.0 | Visual Studio 2019 v16.11.27 |
|
||||
|
||||
Note: GCC 8.5.0 has known regressions regarding fold expressions and overloaded operators. Using GCC 7.5.0 or (preferred) GCC >= 9 is recommended.
|
||||
|
||||
Note: CUTLASS 3.x builds are known to be down on Windows platforms for all CUDA toolkits.
|
||||
CUTLASS team is working on a fix.
|
||||
|
||||
## Hardware
|
||||
|
||||
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on Volta, Turing, Ampere, Ada, and Hopper architecture based NVIDIA GPUs.
|
||||
|
||||
|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit Required by CUTLASS-3**|
|
||||
|---|---|---|
|
||||
|NVIDIA V100 Tensor Core GPU |7.0|11.4|
|
||||
|NVIDIA TitanV |7.0|11.4|
|
||||
|NVIDIA GeForce RTX 2080 TI, 2080, 2070 |7.5|11.4|
|
||||
|NVIDIA GeForce RTX 20x0 series |7.5|11.4|
|
||||
|NVIDIA T4 |7.5|11.4|
|
||||
|NVIDIA A100 Tensor Core GPU |8.0|11.4|
|
||||
|NVIDIA A10 |8.6|11.4|
|
||||
|NVIDIA GeForce RTX 3090 |8.6|11.4|
|
||||
|NVIDIA GeForce RTX 4090 |8.9|11.8|
|
||||
|NVIDIA GeForce RTX 30x0 series |8.6|11.4|
|
||||
|NVIDIA GeForce RTX 40x0 series |8.9|11.8|
|
||||
|NVIDIA L40 |8.9|11.8|
|
||||
|NVIDIA H100 Tensor Core GPU |9.0|11.8|
|
||||
|NVIDIA H200 Tensor Core GPU |9.0|11.8|
|
||||
|NVIDIA B200 Tensor Core GPU |10.0|12.8|
|
||||
|NVIDIA GeForce RTX 50x0 series |10.0|12.8|
|
||||
|
||||
## Target Architecture
|
||||
|
||||
In general, PTX code generated for one target architecture can be run on future architectures (i.e., it is forward compatible). However, CUDA 12.0 introduced the concept of "architecture-accelerated features" whose PTX does not have forward compatibility guarantees. Several Hopper PTX instructions fall under this category of architecture-accelerated features, and thus require a `sm_90a` target architecture (note the "a" appended). For more details on this and other architecture-accelerated instructions, please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability).
|
||||
In general, PTX code generated for one target architecture can be run on future architectures
|
||||
(i.e., it is forward compatible).
|
||||
However, CUDA 12.0 introduced the concept of "architecture-accelerated features" whose
|
||||
PTX does not have forward compatibility guarantees.
|
||||
Several Hopper and Blackwell PTX instructions fall under this category of
|
||||
architecture-accelerated features, and thus require a `sm_90a` or `sm100a` target architecture
|
||||
(note the "a" appended). For more details on this and other architecture-accelerated instructions,
|
||||
please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability).
|
||||
|
||||
The target architecture information is passed on to CUTLASS via the cmake flag `CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100, users are required to build CUTLASS with `90a` as the target architecture. If a user accidentally builds a kernel which uses SM90a features (e.g. Hopper Tensor Core Instructions), using the SM90 target (note the lack of "a"), with either CUDA Toolkit 12 or 11.8, the kernel is expected to fail with a runtime error.
|
||||
The target architecture information is passed on to CUTLASS via the cmake flag
|
||||
`CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100,
|
||||
users are required to build CUTLASS with `90a` as the target architecture.
|
||||
If a user accidentally builds a kernel which uses SM90a features
|
||||
(e.g. Hopper Tensor Core Instructions), using the SM90 target
|
||||
(note the lack of "a"), with either CUDA Toolkit 12 or 11.8,
|
||||
the kernel is expected to fail with a runtime error.
|
||||
|
||||
```
|
||||
cmake .. -DCUTLASS_NVCC_ARCHS="90a"
|
||||
cmake .. -DCUTLASS_NVCC_ARCHS="90a"
|
||||
```
|
||||
Or
|
||||
|
||||
```
|
||||
cmake .. -DCUTLASS_NVCC_ARCHS="100a"
|
||||
```
|
||||
|
||||
Please refer to the [functionality documentation](./media/docs/functionality.md) for details on which kernels require which target architectures.
|
||||
Note: The NVIDIA Blackwell SM100 architecture used in the datacenter
|
||||
products has a different compute capability than the one underpinning
|
||||
NVIDIA Blackwell GeForce RTX 50 series GPUs. As a result, kernels
|
||||
compiled for Blackwell SM100 architecture with arch conditional features
|
||||
(using `sm100a`) are not compatible with RTX 50 series GPUs.
|
||||
|
||||
Please refer to the [functionality documentation](./media/docs/cpp/functionality.md)
|
||||
for details on which kernels require which target architectures.
|
||||
|
||||
# Documentation
|
||||
|
||||
CUTLASS is described in the following documents and the accompanying
|
||||
[Doxygen documentation](https://nvidia.github.io/cutlass).
|
||||
|
||||
- [Quick Start Guide](./media/docs/quickstart.md) - build and run CUTLASS
|
||||
- [Functionality](./media/docs/functionality.md) - summarizes functionality available in CUTLASS
|
||||
- [Efficient GEMM in CUDA](./media/docs/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA
|
||||
- [CUTLASS 3.x Design](./media/docs/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
|
||||
- [GEMM API 3.x](./media/docs/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts
|
||||
- [GEMM API 2.x](./media/docs/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts
|
||||
- [Implicit GEMM Convolution](./media/docs/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS
|
||||
- [Code Organization](./media/docs/code_organization.md) - describes the organization and contents of the CUTLASS project
|
||||
- [Terminology](./media/docs/terminology.md) - describes terms used in the code
|
||||
- [Programming Guidelines](./media/docs/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++
|
||||
- [Fundamental types](./media/docs/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
|
||||
- [Layouts](./media/docs/layout.md) - describes layouts of matrices and tensors in memory
|
||||
- [Tile Iterators](./media/docs/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory
|
||||
- [CUTLASS Profiler](./media/docs/profiler.md) - command-line driven profiling application
|
||||
- [CUTLASS Utilities](./media/docs/utilities.md) - additional templates used to facilate rapid development
|
||||
- [Dependent kernel launch](./media/docs/dependent_kernel_launch.md) - describes a new feature in Hopper which allows overlapping dependent
|
||||
- [Quick Start Guide](./media/docs/cpp/quickstart.md) - basics of building and running CUTLASS
|
||||
- [Functionality](./media/docs/cpp/functionality.md) - summarizes functionality available in CUTLASS
|
||||
- [Efficient GEMM in CUDA](./media/docs/cpp/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA
|
||||
- [CUTLASS 3.x Design](./media/docs/cpp/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
|
||||
- [GEMM API 3.x](./media/docs/cpp/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts
|
||||
- [GEMM API 2.x](./media/docs/cpp/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts
|
||||
- [Implicit GEMM Convolution](./media/docs/cpp/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS
|
||||
- [Code Organization](./media/docs/cpp/code_organization.md) - describes the organization and contents of the CUTLASS project
|
||||
- [Terminology](./media/docs/cpp/terminology.md) - describes terms used in the code
|
||||
- [Programming Guidelines](./media/docs/cpp/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++
|
||||
- [Fundamental types](./media/docs/cpp/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
|
||||
- [Layouts](./media/docs/cpp/layout.md) - describes layouts of matrices and tensors in memory
|
||||
- [Tile Iterators](./media/docs/cpp/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory
|
||||
- [CUTLASS Profiler](./media/docs/cpp/profiler.md) - command-line driven profiling application
|
||||
- [CUTLASS Utilities](./media/docs/cpp/utilities.md) - additional templates used to facilitate rapid development
|
||||
- [Dependent kernel launch](./media/docs/cpp/dependent_kernel_launch.md) - describes a new feature in Hopper which allows overlapping dependent
|
||||
kernels in the same stream, and how it is used in CUTLASS.
|
||||
|
||||
# Resources
|
||||
We have also described the structure of an efficient GEMM in our talk at the
|
||||
[GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).
|
||||
|
||||
- [CUTLASS: Software Primitives for Dense Linear Algebra at All Levels and Scales within CUDA](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2018-s8854/)
|
||||
- [Developing CUDA Kernels to Push Tensor Cores to the Absolute Limit on NVIDIA A100](https://www.nvidia.com/en-us/on-demand/session/gtcsj20-s21745/)
|
||||
- [Accelerating Convolution with Tensor Cores in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring21-s31883/)
|
||||
- [Accelerating Backward Data Gradient by Increasing Tensor Core Utilization in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41996/)
|
||||
- [CUTLASS: Python API, Enhancements, and NVIDIA Hopper](https://www.nvidia.com/en-us/on-demand/session/gtcfall22-a41131/)
|
||||
- [CUTLASS: Software Primitives for Dense Linear Algebra at All Levels and Scales within CUDA](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2018-s8854/)
|
||||
- [Developing CUDA Kernels to Push Tensor Cores to the Absolute Limit on NVIDIA A100](https://www.nvidia.com/en-us/on-demand/session/gtcsj20-s21745/)
|
||||
- [Accelerating Convolution with Tensor Cores in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring21-s31883/)
|
||||
- [Accelerating Backward Data Gradient by Increasing Tensor Core Utilization in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41996/)
|
||||
- [CUTLASS: Python API, Enhancements, and NVIDIA Hopper](https://www.nvidia.com/en-us/on-demand/session/gtcfall22-a41131/)
|
||||
|
||||
# Building CUTLASS
|
||||
|
||||
@ -176,7 +251,7 @@ projects. Client applications should target CUTLASS's `include/` directory in th
|
||||
paths.
|
||||
|
||||
CUTLASS unit tests, examples, and utilities can be build with CMake.
|
||||
The minimum version of CMake is given in the [Quickstart guide](./media/docs/quickstart.md).
|
||||
The minimum version of CMake is given in the [Quickstart guide](./media/docs/cpp/quickstart.md).
|
||||
Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed
|
||||
on your system.
|
||||
|
||||
@ -221,7 +296,7 @@ CUTLASS is arranged as a header-only library along with Utilities, Tools, Exampl
|
||||
and template concepts defined in the CUTLASS project.
|
||||
|
||||
A detailed explanation of the source code organization may be found in the
|
||||
[CUTLASS documentation](./media/docs/code_organization.md), but several main components are summarized below.
|
||||
[CUTLASS documentation](./media/docs/cpp/code_organization.md), but several main components are summarized below.
|
||||
|
||||
## CUTLASS Template Library
|
||||
|
||||
@ -295,7 +370,7 @@ tools/
|
||||
The `test/unit/` directory consist of unit tests implemented with Google Test that demonstrate
|
||||
basic usage of Core API components and complete tests of the CUTLASS GEMM computations.
|
||||
|
||||
Instructions for building and running the Unit tests are described in the [Quickstart guide](./media/docs/quickstart.md).
|
||||
Instructions for building and running the Unit tests are described in the [Quickstart guide](./media/docs/cpp/quickstart.md).
|
||||
|
||||
# Performance Profiling
|
||||
|
||||
@ -511,9 +586,9 @@ reference_device: Passed
|
||||
|
||||
## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler
|
||||
- Please follow the links for more CMake examples on selectively compiling CUTLASS kernels:
|
||||
- [GEMM CMake Examples](./media/docs/quickstart.md#gemm-cmake-examples)
|
||||
- [Implicit GEMM convolution CMake Examples](./media/docs/quickstart.md#convolution-cmake-examples)
|
||||
- [Further details about the CUTLASS Profiler are described here.](./media/docs/profiler.md)
|
||||
- [GEMM CMake Examples](./media/docs/cpp/quickstart.md#gemm-cmake-examples)
|
||||
- [Implicit GEMM convolution CMake Examples](./media/docs/cpp/quickstart.md#convolution-cmake-examples)
|
||||
- [Further details about the CUTLASS Profiler are described here.](./media/docs/cpp/profiler.md)
|
||||
|
||||
|
||||
# About
|
||||
|
||||
92
customConfigs.cmake
Normal file
92
customConfigs.cmake
Normal file
@ -0,0 +1,92 @@
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# Profiler based functional testing
|
||||
set(CUTLASS_BUILD_FOR_PROFILER_REGRESSIONS OFF CACHE BOOL "Utilize profiler-based functional regressions")
|
||||
set(CUTLASS_PROFILER_REGRESSION_TEST_LEVEL ${CUTLASS_TEST_LEVEL} CACHE STRING "Profiler functional regression test level")
|
||||
|
||||
find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED)
|
||||
|
||||
function(cutlass_generate_kernel_filter_and_testlists_files)
|
||||
|
||||
set(options)
|
||||
set(oneValueArgs TEST_SET_NAME)
|
||||
set(multiValueArgs)
|
||||
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CUTLASS_LIBRARY_PACKAGE_DIR}
|
||||
${Python3_EXECUTABLE} ${CUTLASS_SOURCE_DIR}/python/cutlass_library/generator.py
|
||||
--generator-target=${__TEST_SET_NAME}
|
||||
--cuda-version=${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR}
|
||||
--architectures=${CUTLASS_NVCC_ARCHS}
|
||||
--kernels=\*
|
||||
--disable-cutlass-package-imports
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
|
||||
RESULT_VARIABLE cutlass_FILTER_GENERATION_RESULT
|
||||
OUTPUT_VARIABLE cutlass_FILTER_GENERATION_OUTPUT
|
||||
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/library_filter_generation.log
|
||||
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/library_filter_generation.log
|
||||
)
|
||||
|
||||
if(NOT cutlass_FILTER_GENERATION_RESULT EQUAL 0)
|
||||
message(FATAL_ERROR "Error generating kernel filters and testlists files. See ${CMAKE_CURRENT_BINARY_DIR}/library_filter_generation.log")
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
if(CUTLASS_BUILD_FOR_PROFILER_REGRESSIONS)
|
||||
|
||||
set(PROFILER_ARCH_LIST 100a 101a 120a)
|
||||
foreach(ARCH IN LISTS CUTLASS_NVCC_ARCHS)
|
||||
if(NOT (ARCH IN_LIST PROFILER_ARCH_LIST))
|
||||
message(FATAL_ERROR "Only SM100a/101a/120a compute capability is supported with profiler-based unit tests")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
if(CUTLASS_PROFILER_REGRESSION_TEST_LEVEL EQUAL 0)
|
||||
|
||||
message(STATUS "Building for L0 profiler-based functional regressions")
|
||||
cutlass_generate_kernel_filter_and_testlists_files(TEST_SET_NAME kernel_testlist_l0)
|
||||
set(KERNEL_FILTER_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L0_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm_kernel_filter.list CACHE STRING "Kernel set")
|
||||
set(CUTLASS_PROFILER_REGRESSION_LIST_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L0_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm.csv CACHE STRING "Regression set")
|
||||
|
||||
elseif (CUTLASS_PROFILER_REGRESSION_TEST_LEVEL EQUAL 1)
|
||||
|
||||
message(STATUS "Building for L1 profiler-based functional regressions")
|
||||
cutlass_generate_kernel_filter_and_testlists_files(TEST_SET_NAME kernel_testlist_l1)
|
||||
set(KERNEL_FILTER_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L1_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm_kernel_filter.list CACHE STRING "Kernel set")
|
||||
set(CUTLASS_PROFILER_REGRESSION_LIST_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L1_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm.csv CACHE STRING "Regression set")
|
||||
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
@ -34,7 +34,7 @@
|
||||
addressable memory, and then store it back into addressable memory.
|
||||
|
||||
TileIterator is a core concept in CUTLASS that enables efficient loading and storing of data to
|
||||
and from addressable memory. The PredicateTileIterator accepts a ThreadMap type, which defines
|
||||
and from addressable memory. The PredicatedTileIterator accepts a ThreadMap type, which defines
|
||||
the mapping of threads to a "tile" in memory. This separation of concerns enables user-defined
|
||||
thread mappings to be specified.
|
||||
|
||||
@ -124,7 +124,7 @@ __global__ void copy(
|
||||
|
||||
cudaError_t TestTileIterator(int M, int K) {
|
||||
|
||||
// For this example, we chose a <64, 4> tile shape. The PredicateTileIterator expects
|
||||
// For this example, we chose a <64, 4> tile shape. The PredicatedTileIterator expects
|
||||
// PitchLinearShape and PitchLinear layout.
|
||||
using Shape = cutlass::layout::PitchLinearShape<64, 4>;
|
||||
using Layout = cutlass::layout::PitchLinear;
|
||||
@ -136,7 +136,7 @@ cudaError_t TestTileIterator(int M, int K) {
|
||||
// dimension then along the strided dimension.
|
||||
using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<Shape, kThreads>;
|
||||
|
||||
// Define the PredicateTileIterator, using TileShape, Element, Layout, and ThreadMap types
|
||||
// Define the PredicatedTileIterator, using TileShape, Element, Layout, and ThreadMap types
|
||||
using Iterator = cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
Shape, Element, Layout, 1, ThreadMap>;
|
||||
|
||||
|
||||
@ -115,4 +115,3 @@ SPDX-License-Identifier: BSD-3-Clause
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
|
||||
@ -2,3 +2,35 @@
|
||||
|
||||
This directory contains deprecated examples for PyCUTLASS, a precursor to the CUTLASS Python interface.
|
||||
For examples of using CUTLASS's actively-maintained Pythonic interface, see the [examples/python](/examples/python) directory.
|
||||
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
@ -165,3 +165,35 @@ Example 7: GELU
|
||||
```python
|
||||
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu
|
||||
```
|
||||
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
@ -483,12 +483,15 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major < 9) {
|
||||
if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -540,6 +540,15 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
else if (__CUDACC_VER_MAJOR__ < 12 || props.major != 9 || props.minor != 0) {
|
||||
std::cout
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture "
|
||||
<< "(compute capability 90) and CUDA 12.0 or greater.\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -356,6 +356,15 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
else if (__CUDACC_VER_MAJOR__ < 12 || props.major != 9 || props.minor != 0) {
|
||||
std::cout
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture "
|
||||
<< "(compute capability 90) and CUDA 12.0 or greater.\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -626,6 +626,13 @@ int main(int argc, const char ** argv) {
|
||||
std::cerr << "This example requires a device with compute capability 90 or higher.\n";
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
else if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
|
||||
if (notSupported) {
|
||||
return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems
|
||||
}
|
||||
|
||||
@ -750,6 +750,13 @@ int main(int argc, char const **argv)
|
||||
std::cerr << "This example requires a device with compute capability 90 or higher.\n";
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
else if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
|
||||
if (notSupported) {
|
||||
return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems
|
||||
}
|
||||
|
||||
@ -566,12 +566,15 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major < 9) {
|
||||
if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -103,11 +103,10 @@
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/mixed_dtype_utils.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
#include "mixed_dtype_utils.hpp"
|
||||
#include "packed_scale.hpp"
|
||||
#include "reorder_utils.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@ -144,8 +143,8 @@ using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
|
||||
using ValueShuffle = Layout<Shape<_2,_4>, Stride<_4,_1>>; // order [0,2,4,6,1,3,5,7]
|
||||
int constexpr NumShuffleAtoms = 1;
|
||||
using MmaAtomShape = Layout<Shape<_1,Int<NumShuffleAtoms>>>;
|
||||
using LayoutAtomQuant = decltype(compute_memory_reordering_atom<MmaType, MmaAtomShape, ValueShuffle>());
|
||||
using LayoutB_Reordered = decltype(tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,int>, StrideB>{}));
|
||||
using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<MmaType, MmaAtomShape, ValueShuffle>());
|
||||
using LayoutB_Reordered = decltype(cute::tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,int>, StrideB>{}));
|
||||
|
||||
using ElementScale = MmaType;
|
||||
using ElementZero = ElementScale;
|
||||
@ -403,7 +402,7 @@ struct Options : MixedDtypeOptions{
|
||||
void initialize(Options const& options) {
|
||||
|
||||
auto shape_B = cute::make_shape(options.n, options.k, options.l);
|
||||
int const scale_k = (options.k + options.g - 1) / options.g;
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.g);
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
|
||||
// Reverse stride here due to swap and transpose
|
||||
@ -430,7 +429,7 @@ void initialize(Options const& options) {
|
||||
block_zero.reset(scale_k * options.l * options.n);
|
||||
|
||||
initialize_tensor(block_A, seed + 2022);
|
||||
initialize_quant_tensor(block_B, seed + 2021);
|
||||
initialize_tensor(block_B, seed + 2021);
|
||||
initialize_tensor(block_C, seed + 2020);
|
||||
initialize_scale(block_scale, options);
|
||||
initialize_zero(block_zero, options);
|
||||
@ -438,14 +437,15 @@ void initialize(Options const& options) {
|
||||
auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l);
|
||||
stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l));
|
||||
stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
|
||||
auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);
|
||||
auto layout_scale_zero = cute::make_layout(shape_scale_zero, stride_S_ref);
|
||||
|
||||
dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g);
|
||||
cudaStream_t stream = cudaStreamDefault;
|
||||
cutlass::dequantize(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g, stream);
|
||||
|
||||
if (options.shuffle) {
|
||||
// Repeat the reorder layout atom to tile the whole tensor shape
|
||||
layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B);
|
||||
reorder_tensor(block_B.get(), layout_B, layout_B_reordered);
|
||||
layout_B_reordered = cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
|
||||
cutlass::reorder_tensor(block_B.get(), layout_B, layout_B_reordered);
|
||||
|
||||
print("Quantized tensor layout: ");
|
||||
print(layout_B_reordered);
|
||||
@ -613,12 +613,15 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major < 9) {
|
||||
if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -107,11 +107,10 @@
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/mixed_dtype_utils.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
#include "mixed_dtype_utils.hpp"
|
||||
#include "packed_scale.hpp"
|
||||
#include "reorder_utils.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@ -144,8 +143,8 @@ using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
|
||||
// Define the CuTe layout for reoredered quantized tensor B
|
||||
// LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory.
|
||||
// It specifies the reordering within a single warp's fragment
|
||||
using LayoutAtomQuant = decltype(compute_memory_reordering_atom<MmaType>());
|
||||
using LayoutB_Reordered = decltype(tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,int>, StrideB>{}));
|
||||
using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<MmaType>());
|
||||
using LayoutB_Reordered = decltype(cute::tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,int>, StrideB>{}));
|
||||
|
||||
using ElementScale = MmaType;
|
||||
using ElementZero = ElementScale; // only for verify
|
||||
@ -319,7 +318,7 @@ struct Options : MixedDtypeOptions {
|
||||
void initialize(Options const& options) {
|
||||
|
||||
auto shape_B = cute::make_shape(options.n, options.k, options.l);
|
||||
int const scale_k = (options.k + options.g - 1) / options.g;
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.g);
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
|
||||
// Reverse stride here due to swap and transpose
|
||||
@ -348,11 +347,11 @@ void initialize(Options const& options) {
|
||||
block_zero.reset(scale_k * options.l * options.n);
|
||||
|
||||
initialize_tensor(block_A, seed + 2022);
|
||||
initialize_quant_tensor(block_B, seed + 2021);
|
||||
unify_quant_encoding(block_B, block_B_modified);
|
||||
initialize_tensor(block_B, seed + 2021);
|
||||
cutlass::unified_encode_int4b(block_B.get(), block_B_modified.get(), block_B.size());
|
||||
initialize_tensor(block_C, seed + 2020);
|
||||
initialize_scale(block_scale, options);
|
||||
initialize_packed_scale(block_scale, block_scale_packed);
|
||||
cutlass::pack_scale_fp8(block_scale.get(), block_scale_packed.get(), block_scale.size());
|
||||
initialize_zero(block_zero, options);
|
||||
|
||||
auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l);
|
||||
@ -360,12 +359,13 @@ void initialize(Options const& options) {
|
||||
stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
|
||||
auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);
|
||||
|
||||
dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g);
|
||||
cudaStream_t stream = cudaStreamDefault;
|
||||
cutlass::dequantize(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g, stream);
|
||||
|
||||
if (options.shuffle) {
|
||||
// Repeat the reorder layout atom to tile the whole tensor shape
|
||||
layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B);
|
||||
reorder_tensor(block_B_modified.get(), layout_B, layout_B_reordered);
|
||||
layout_B_reordered = cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
|
||||
cutlass::reorder_tensor(block_B_modified.get(), layout_B, layout_B_reordered);
|
||||
|
||||
print("Quantized tensor layout: ");
|
||||
print(layout_B_reordered);
|
||||
@ -518,12 +518,15 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major < 9) {
|
||||
if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -100,6 +100,7 @@
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/mixed_dtype_utils.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
#include "mixed_dtype_utils.hpp"
|
||||
@ -287,7 +288,7 @@ cutlass::DeviceAllocation<typename GemmScaleWithZeroPoint::EpilogueOutputOp::Ele
|
||||
void initialize(MixedDtypeOptions const& options) {
|
||||
|
||||
auto shape_b = cute::make_shape(options.n, options.k, options.l);
|
||||
int const scale_k = (options.k + options.g - 1) / options.g;
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.g);
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b);
|
||||
// Reverse stride here due to swap and transpose
|
||||
@ -312,7 +313,7 @@ void initialize(MixedDtypeOptions const& options) {
|
||||
block_zero.reset(scale_k * options.l * options.n);
|
||||
|
||||
initialize_tensor(block_A, seed + 2022);
|
||||
initialize_quant_tensor(block_B, seed + 2021);
|
||||
initialize_tensor(block_B, seed + 2021);
|
||||
initialize_tensor(block_C, seed + 2020);
|
||||
initialize_scale(block_scale, options);
|
||||
initialize_zero(block_zero, options);
|
||||
@ -322,9 +323,10 @@ void initialize(MixedDtypeOptions const& options) {
|
||||
auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l);
|
||||
stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l));
|
||||
stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l));
|
||||
auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref);
|
||||
auto layout_scale_zero = cute::make_layout(shape_scale_zero, stride_S_ref);
|
||||
|
||||
dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g);
|
||||
cudaStream_t stream = cudaStreamDefault;
|
||||
cutlass::dequantize(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g, stream);
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
@ -483,12 +485,15 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major < 9) {
|
||||
if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -7,14 +7,18 @@ When relying on `KernelScheduleAuto`, the main loop supporting different A and B
|
||||
|
||||
This first version only supports mixed type GEMMs using TMA.
|
||||
|
||||
|
||||
## Performance
|
||||
|
||||
While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x {int8, int4}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance for `fp16, bf16` or `fp32` scales and zero-points. For best performance, it is ideal to have the scales and zero-points be the same type.
|
||||
While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x {int8, int4, int2}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance for `fp16`, `bf16` or `fp32` scales and zero-points. For best performance, it is ideal to have the scales and zero-points be the same type as mma's type.
|
||||
|
||||
The scale only mode for `fp8 x int4` is significantly slower than direct conversion mode. There is a lookup-table workaround targeting this mode, as shown in `55_hopper_int4_fp8_gemm.cu`. To use this feature, use `cutlass::Array<ElementScale, 8>` as the scale type in the collective builder. However, it requires modifications to the encoding of quantized weights and scale factors. Also, scale with zero point mode is not supported for now.
|
||||
|
||||
Additionally, it's recommended to reorder the narrow data type tensor such that elements read into register file by the same thread are contiguous in global and shared memory. The user can use the helper function `compute_memory_reordering_atom` and `reorder_tensor` to achieve this. See `55_hopper_int4_fp8_gemm.cu` and `55_hopper_int4_bf16_gemm.cu` for more details.
|
||||
|
||||
We are currently optimizing the following cases:
|
||||
1. Memory bound cases for all types
|
||||
2. `fp8 x {int2, uint2}` case
|
||||
|
||||
## Limitations
|
||||
|
||||
@ -37,3 +41,35 @@ We are currently optimizing the following cases:
|
||||
* Optimizations for memory bound cases.
|
||||
|
||||
* Optimizations for scale and zero-point loading when the group size is not equal to the threadblock-k size.
|
||||
|
||||
## Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
@ -60,8 +60,8 @@ struct MixedDtypeOptions {
|
||||
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
int iterations = 1000;
|
||||
int warmup = 1000;
|
||||
int iterations = 100;
|
||||
int warmup = 10;
|
||||
int mode = 1;
|
||||
int m = 5120, n = 4096, k = 4096;
|
||||
int g = 128;
|
||||
@ -151,16 +151,16 @@ void mixed_dtype_profiling(
|
||||
runtimes.reserve(options.iterations);
|
||||
|
||||
for (int iter = 0; iter < options.warmup + options.iterations; ++iter) {
|
||||
cudaEventRecord(start);
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
cudaEventRecord(stop);
|
||||
cudaEventSynchronize(stop);
|
||||
cudaEventRecord(start);
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
cudaEventRecord(stop);
|
||||
cudaEventSynchronize(stop);
|
||||
|
||||
if (iter >= options.warmup) {
|
||||
float milliseconds = 0;
|
||||
cudaEventElapsedTime(&milliseconds, start, stop);
|
||||
runtimes.push_back(milliseconds);
|
||||
}
|
||||
if (iter >= options.warmup) {
|
||||
float milliseconds = 0;
|
||||
cudaEventElapsedTime(&milliseconds, start, stop);
|
||||
runtimes.push_back(milliseconds);
|
||||
}
|
||||
}
|
||||
|
||||
cudaEventDestroy(start);
|
||||
@ -208,42 +208,22 @@ bool initialize_tensor(
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename Element>
|
||||
bool initialize_quant_tensor(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed = 2023) {
|
||||
|
||||
float scope_min = float(cutlass::platform::numeric_limits<Element>::lowest());
|
||||
float scope_max = float(cutlass::platform::numeric_limits<Element>::max());
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class Element>
|
||||
bool initialize_scale(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
MixedDtypeOptions const& options,
|
||||
uint64_t seed = 2023) {
|
||||
|
||||
if (options.mode == MixedDtypeGemmMode::ConvertOnly) {
|
||||
// No scales, so just initialize with 1 so we can use the same kernel to dequantize the data.
|
||||
std::vector<Element> stage(block.size(), Element(1.0f));
|
||||
block.copy_from_host(stage.data());
|
||||
}
|
||||
else {
|
||||
// If no scales, initialize with 1 so we can use the same kernel to dequantize the data
|
||||
float scope_max = 1.0f, scope_min = 1.0f;
|
||||
if (options.mode != MixedDtypeGemmMode::ConvertOnly) {
|
||||
float elt_max_f = float(cutlass::platform::numeric_limits<Element>::max());
|
||||
const float max_dequant_val = 4.f;
|
||||
const float min_dequant_val = 0.5f;
|
||||
|
||||
float scope_max(max_dequant_val / elt_max_f);
|
||||
float scope_min(min_dequant_val / elt_max_f);
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
|
||||
scope_max = 2.f;
|
||||
scope_min = 0.1f;
|
||||
}
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -253,139 +233,14 @@ bool initialize_zero(
|
||||
MixedDtypeOptions const& options,
|
||||
uint64_t seed = 2023) {
|
||||
|
||||
// If no bias, initialize with 0 so we can use the same kernel to dequantize the data
|
||||
float scope_max = 0.0f, scope_min = 0.0f;
|
||||
if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) {
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, Element(2.0f), Element(-2.0f));
|
||||
} else {
|
||||
// No bias, so just initialize with 1 so we can use the same kernel to dequantize the data.
|
||||
std::vector<Element> stage(block.size(), Element(0.0f));
|
||||
block.copy_from_host(stage.data());
|
||||
scope_max = 2.0f;
|
||||
scope_min = -2.0f;
|
||||
}
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Dequantize the weights for verification
|
||||
|
||||
template <class QuantizedElement,
|
||||
class DequantizedElement,
|
||||
class OperandLayout,
|
||||
class ElementScale,
|
||||
class ElementZero,
|
||||
class ScaleBroadCastLayout,
|
||||
class ThrLayout>
|
||||
__global__ void dequantize_weight_kernel(DequantizedElement* dq_buffer,
|
||||
QuantizedElement const* q_buffer,
|
||||
OperandLayout const operand_layout,
|
||||
ElementScale const* scale_buffer,
|
||||
ElementZero const* zero_buffer,
|
||||
ScaleBroadCastLayout const broadcasted_scale_layout,
|
||||
ThrLayout thr_layout) {
|
||||
using namespace cute;
|
||||
|
||||
// Represent the full tensors to gmem elements.
|
||||
// These are expected to have shape [MN, K, L]
|
||||
cute::Tensor gmem_op_dq = cute::make_tensor(cute::make_gmem_ptr(dq_buffer), operand_layout);
|
||||
auto init_quantized_iterator = [&]() {
|
||||
if constexpr (cute::sizeof_bits_v<QuantizedElement> >= 8) {
|
||||
return cute::make_gmem_ptr(q_buffer);
|
||||
} else {
|
||||
return cute::subbyte_iterator<const QuantizedElement>(q_buffer);
|
||||
}
|
||||
};
|
||||
cute::Tensor gmem_op_q = cute::make_tensor(init_quantized_iterator(), operand_layout);
|
||||
// While the scales are expected to have shape [MN, G, L] but with a stride to allow broadcasting
|
||||
// It is expected that K % G == 0
|
||||
cute::Tensor gmem_scale_broadcasted = cute::make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout);
|
||||
cute::Tensor gmem_zero_broadcasted = cute::make_tensor(make_gmem_ptr(zero_buffer), broadcasted_scale_layout);
|
||||
|
||||
// Assign 1 thread per element in the thread block
|
||||
auto blk_shape = make_shape(size<0>(thr_layout), _1{}, _1{}); //
|
||||
auto blk_coord = make_coord(_, blockIdx.x, blockIdx.y); // (MN, K, L)
|
||||
|
||||
// Tile across the block
|
||||
auto gOp_dq = cute::local_tile(gmem_op_dq, blk_shape, blk_coord);
|
||||
auto gScale = cute::local_tile(gmem_scale_broadcasted, blk_shape, blk_coord);
|
||||
auto gZero = cute::local_tile(gmem_zero_broadcasted, blk_shape, blk_coord);
|
||||
auto gOp_q = cute::local_tile(gmem_op_q, blk_shape, blk_coord);
|
||||
|
||||
auto tOpDq_gOpDq = cute::local_partition(gOp_dq, thr_layout, threadIdx.x);
|
||||
auto tScale_gScale = cute::local_partition(gScale, thr_layout, threadIdx.x);
|
||||
auto tZero_gZero = cute::local_partition(gZero, thr_layout, threadIdx.x);
|
||||
auto tOpQ_gOpQ = cute::local_partition(gOp_q, thr_layout, threadIdx.x);
|
||||
|
||||
// Make a fragment of registers to hold gmem loads
|
||||
cute::Tensor rmem_op_q = cute::make_fragment_like(tOpQ_gOpQ(_, _, _, 0));
|
||||
cute::Tensor rmem_scale = cute::make_fragment_like(tScale_gScale(_, _, _, 0));
|
||||
cute::Tensor rmem_zero = cute::make_fragment_like(tZero_gZero(_, _, _, 0));
|
||||
cute::Tensor rmem_op_dq = cute::make_fragment_like(tOpDq_gOpDq(_, _, _, 0));
|
||||
cute::Tensor rmem_op_scaled = cute::make_fragment_like<ElementScale>(rmem_op_dq);
|
||||
cute::Tensor rmem_zero_buf = cute::make_fragment_like<ElementScale>(rmem_zero);
|
||||
|
||||
cute::Tensor pred_id = cute::make_identity_tensor(shape(operand_layout));
|
||||
auto pred_blk_tile = cute::local_tile(pred_id, blk_shape, blk_coord);
|
||||
auto pred_thr_partition = cute::local_partition(pred_blk_tile, thr_layout, threadIdx.x);
|
||||
|
||||
const auto num_iters = cute::size<3>(tOpDq_gOpDq);
|
||||
|
||||
for (int ii = 0; ii < num_iters; ++ii) {
|
||||
const auto thread_offset = cute::get<0>(pred_thr_partition(0, 0, 0, ii));
|
||||
if (thread_offset < cute::size<0>(operand_layout)) {
|
||||
cute::copy(tOpQ_gOpQ(_, _, _, ii), rmem_op_q);
|
||||
cute::copy(tScale_gScale(_, _, _, ii), rmem_scale);
|
||||
cute::copy(tZero_gZero(_, _, _, ii), rmem_zero);
|
||||
cute::transform(rmem_op_q, rmem_op_scaled, [] (const QuantizedElement& elt) { return ElementScale(elt); } );
|
||||
cute::transform(rmem_zero, rmem_zero_buf, [] (const ElementZero& elt) { return ElementScale(elt); } );
|
||||
cute::transform(rmem_op_scaled, rmem_scale, rmem_op_scaled, multiplies{});
|
||||
cute::transform(rmem_op_scaled, rmem_zero_buf, rmem_op_scaled, plus{});
|
||||
cute::transform(rmem_op_scaled, rmem_op_dq, [] (const ElementScale& elt) { return DequantizedElement(elt); } );
|
||||
cute::copy(rmem_op_dq, tOpDq_gOpDq(_, _, _, ii));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class QuantizedElement,
|
||||
class DequantizedElement,
|
||||
class OperandLayout,
|
||||
class ElementScale,
|
||||
class ElementZero,
|
||||
class ScaleLayout>
|
||||
void dequantize_weight(DequantizedElement* dq_buffer,
|
||||
QuantizedElement const* q_buffer,
|
||||
OperandLayout const operand_layout,
|
||||
ElementScale const* scale_buffer,
|
||||
ElementZero const* zero_buffer,
|
||||
ScaleLayout const scale_layout,
|
||||
int const group_size) {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
constexpr int tpb = 128;
|
||||
auto thr_layout = make_layout(make_shape(Int<tpb>{}));
|
||||
|
||||
const auto num_rows = get<0>(shape(operand_layout));
|
||||
const auto gemm_k = get<1>(shape(operand_layout)); // [MN, K, L]
|
||||
const auto batches = get<2>(shape(operand_layout)); // [MN, K, L]
|
||||
const auto scale_k = get<1>(shape(scale_layout)); // [MN, Scale_K, L]
|
||||
|
||||
if (num_rows != size<0>(scale_layout)) {
|
||||
std::cerr << "Invalid first dimension for scales. Must match first dim for weights."
|
||||
<< " But got shapes " << shape(operand_layout) << " " << shape(scale_layout)
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
const auto scale_stride0 = get<0>(stride(scale_layout));
|
||||
const auto scale_stride1 = get<1>(stride(scale_layout));
|
||||
const auto scale_stride2 = get<2>(stride(scale_layout));
|
||||
|
||||
auto scale_shape_bcast = make_shape(num_rows, make_shape(group_size, scale_k), batches);
|
||||
auto scale_stride_bcast = make_stride(scale_stride0, make_stride(0, scale_stride1), scale_stride2);
|
||||
auto scale_layout_bcast = make_layout(scale_shape_bcast, scale_stride_bcast);
|
||||
|
||||
const auto blocks_x = gemm_k;
|
||||
const auto blocks_y = batches;
|
||||
|
||||
dim3 blocks(blocks_x, blocks_y, 1);
|
||||
dequantize_weight_kernel<<<blocks, tpb>>>(dq_buffer, q_buffer, operand_layout, scale_buffer, zero_buffer, scale_layout_bcast, thr_layout);
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
@ -1,210 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "cutlass/float8.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/util/type_traits.hpp"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
template<typename T>
|
||||
class packed_scale_t {
|
||||
public:
|
||||
static_assert(cute::is_same_v<T, cutlass::int8_t> ||
|
||||
cute::is_same_v<T, cutlass::uint8_t> ||
|
||||
cute::is_same_v<T, cutlass::float_e4m3_t> ||
|
||||
cute::is_same_v<T, cutlass::float_e5m2_t>,
|
||||
"only 8 bit arithmetic types are supported.");
|
||||
CUTLASS_HOST_DEVICE
|
||||
explicit packed_scale_t(T val) {
|
||||
if constexpr (!cute::is_unsigned_v<T>) {
|
||||
// Only pack negative values. The positive values are generated in flight in the mainloop.
|
||||
storage[0] = pack4(T(float(val) * -8.f), T(float(val) * -7.f), T(float(val) * -6.f), T(float(val) * -5.f));
|
||||
storage[1] = pack4(T(float(val) * -4.f), T(float(val) * -3.f), T(float(val) * -2.f), -val);
|
||||
}
|
||||
else {
|
||||
storage[0] = pack4(T(float(val) * 8.f), T(float(val) * 7.f), T(float(val) * 6.f), T(float(val) * 5.f));
|
||||
storage[1] = pack4(T(float(val) * 4.f), T(float(val) * 3.f), T(float(val) * 2.f), val);
|
||||
}
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
packed_scale_t() = default;
|
||||
CUTLASS_HOST_DEVICE
|
||||
explicit operator float() const {
|
||||
return float(get());
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator==(packed_scale_t const& rhs) const {
|
||||
return storage[0] == rhs.storage[0] && storage[1] == rhs.storage[1];
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator!=(packed_scale_t const& rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
friend packed_scale_t operator+(packed_scale_t const& lhs, packed_scale_t const& rhs) {
|
||||
return packed_scale_t(lhs.get() + rhs.get());
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
friend packed_scale_t operator-(packed_scale_t const& lhs, packed_scale_t const& rhs) {
|
||||
return packed_scale_t(lhs.get() - rhs.get());
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
friend packed_scale_t operator*(packed_scale_t const& lhs, packed_scale_t const& rhs) {
|
||||
return packed_scale_t(lhs.get() * rhs.get());
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
friend packed_scale_t operator/(packed_scale_t const& lhs, packed_scale_t const& rhs) {
|
||||
return packed_scale_t(lhs.get() / rhs.get());
|
||||
}
|
||||
|
||||
private:
|
||||
using Storage = uint32_t;
|
||||
using Stage = uint8_t;
|
||||
|
||||
Storage storage[2] {};
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Storage pack4(T c1, T c2, T c3, T c4) {
|
||||
Storage result = 0;
|
||||
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c4)) << 24);
|
||||
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c3)) << 16);
|
||||
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c2)) << 8);
|
||||
result |= static_cast<Storage>(reinterpret_cast<Stage const&>(c1));
|
||||
return result;
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
T get() const {
|
||||
auto stage = static_cast<Stage>(storage[0] >> 8);
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return reinterpret_cast<T const&>(stage);
|
||||
#else
|
||||
T tmp;
|
||||
std::memcpy(&tmp, &stage, sizeof(Stage));
|
||||
return tmp;
|
||||
#endif
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
T get(int idx) const {
|
||||
Stage stage;
|
||||
if (idx < 4) stage = static_cast<Stage>(storage[0] >> (8 * idx));
|
||||
else stage = static_cast<Stage>(storage[1] >> (8 * idx - 32));
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return reinterpret_cast<T const&>(stage);
|
||||
#else
|
||||
T tmp;
|
||||
std::memcpy(&tmp, &stage, sizeof(Stage));
|
||||
return tmp;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Helpers to initialize scale lookup table
|
||||
|
||||
// In the mainloop, PRMT selects 1 byte from only 8 bytes so the sign bit is handled in an extra PRMT.
|
||||
// Here the encodings of positive values and negative values are unified (except for the sign bit).
|
||||
// For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111).
|
||||
bool unify_quant_encoding(
|
||||
cutlass::DeviceAllocation<cutlass::int4b_t> const& block_in,
|
||||
cutlass::DeviceAllocation<cutlass::int4b_t>& block_out) {
|
||||
|
||||
using StorageType = cutlass::int4b_t::Storage;
|
||||
|
||||
if (block_in.size() != block_out.size()) {
|
||||
std::cerr << "block_in and block_out must have same size.\n";
|
||||
return false;
|
||||
}
|
||||
constexpr int pack = cute::sizeof_bits_v<StorageType> / 4;
|
||||
std::vector<StorageType> data(block_in.size() / pack);
|
||||
cutlass::device_memory::copy_to_host(data.data(), (StorageType*)block_in.get(), block_in.size() / pack);
|
||||
|
||||
for (auto&& d : data) {
|
||||
StorageType out = 0;
|
||||
StorageType mask = 0x0f;
|
||||
for (int i = 0; i < pack; ++i) {
|
||||
cutlass::int4b_t curr;
|
||||
curr.storage = (d >> (i * 4)) & 0x0f;
|
||||
switch (curr) {
|
||||
case 1: curr.storage = StorageType(0b0111); break; // 2's complement
|
||||
case 2: curr.storage = StorageType(0b0110); break; // 2's complement
|
||||
case 3: curr.storage = StorageType(0b0101); break; // 2's complement
|
||||
case 4: curr.storage = StorageType(0b0100); break; // 2's complement
|
||||
case 5: curr.storage = StorageType(0b0011); break; // 2's complement
|
||||
case 6: curr.storage = StorageType(0b0010); break; // 2's complement
|
||||
case 7: curr.storage = StorageType(0b0001); break; // 2's complement
|
||||
default: break;
|
||||
}
|
||||
out |= (curr.storage << (4 * i)) & mask;
|
||||
mask <<= 4;
|
||||
}
|
||||
d = out;
|
||||
}
|
||||
|
||||
cutlass::device_memory::copy_to_device((StorageType*)block_out.get(), data.data(), block_out.size() / pack);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class ElementScale>
|
||||
bool initialize_packed_scale(
|
||||
cutlass::DeviceAllocation<ElementScale> const& block_in,
|
||||
cutlass::DeviceAllocation<cutlass::Array<ElementScale, 8> > & block_out) {
|
||||
|
||||
std::vector<ElementScale> data_in(block_in.size());
|
||||
std::vector<cutlass::Array<ElementScale, 8> > data_out(block_in.size());
|
||||
try {
|
||||
block_in.copy_to_host(data_in.data());
|
||||
} catch (cutlass::cuda_exception const& e)
|
||||
{
|
||||
std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < block_in.size(); ++i)
|
||||
{
|
||||
cutlass::packed_scale_t<ElementScale> tmp(data_in[i]);
|
||||
data_out[i] = reinterpret_cast<cutlass::Array<ElementScale, 8> const&>(tmp);
|
||||
// std::cout << data_in[i] << ":" << std::hex << static_cast<uint16_t>(data_in[i].storage) << ",\t" << -data_in[i] << ":" << std::hex << static_cast<uint16_t>((-data_in[i]).storage) << std::endl;
|
||||
}
|
||||
try {
|
||||
block_out.copy_from_host(data_out.data());
|
||||
} catch (cutlass::cuda_exception const& e)
|
||||
{
|
||||
std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@ -1,162 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include "cute/layout.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/arch/mma_sm90.hpp"
|
||||
|
||||
#include "cutlass/util/device_memory.h"
|
||||
|
||||
// Given a type of MMA instruction, compute a memory reordering atom that places all values
|
||||
// owned by each thread in contiguous memory locations. This improves smem load vectorization,
|
||||
// particularly for mixed dtype GEMMs where a narrow type is loaded in the thread/value order
|
||||
// of the wider type and may result in inefficient sub-bank (8-bit or 16-bit) accesses.
|
||||
// In addition, we can reorder the values across several MMA instructions to get even wider
|
||||
// vectorization (AtomLayout parameter) and permute the values within each instruction to get
|
||||
// more optimal conversion instruction sequences (ValLayout parameter).
|
||||
template<class ElementMma,
|
||||
class AtomLayout = cute::Layout<cute::_1>,
|
||||
class ValLayout = cute::Layout<cute::_1>>
|
||||
constexpr auto compute_memory_reordering_atom(AtomLayout atom_layout = {}, ValLayout val_layout = {})
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
static_assert(is_static_v<ValLayout>, "ValLayout must be static");
|
||||
static_assert(is_static_v<AtomLayout>, "AtomLayout must be static");
|
||||
|
||||
// 1. Choose an MMA atom to access TV layout and MN shape
|
||||
// Note: parameters like GMMA Major, TileShape, ElementC don't affect TV layout of A, use arbitrary
|
||||
using MmaAtom = decltype(SM90::GMMA::rs_op_selector<ElementMma, ElementMma, float, Shape<_64,_16,_32>>());
|
||||
using MmaTraits = MMA_Traits<MmaAtom>;
|
||||
auto mk_shape_mma = select<0,2>(typename MmaTraits::Shape_MNK{});
|
||||
auto tv_layout_mma = typename MmaTraits::ALayout{};
|
||||
static_assert(size<1>(tv_layout_mma) % size(val_layout) == 0, "Value layout must evenly divide the MMA value layout");
|
||||
|
||||
// 2. Create a single warp's TV layout from that of the whole MMA and invert to get (m,k -> thr,val)
|
||||
// Note: this assumes A is partitioned between warps along M mode
|
||||
auto tv_tiler_warp = make_shape(Int<32>{}, size<1>(tv_layout_mma));
|
||||
auto mk_shape_warp = shape_div(mk_shape_mma, size(typename MmaTraits::ThrID{}) / Int<32>{});
|
||||
auto tv_layout_mma_warp = make_layout_like(composition(tv_layout_mma, tv_tiler_warp));
|
||||
auto mk_layout_mma_warp = right_inverse(tv_layout_mma_warp).with_shape(mk_shape_warp);
|
||||
|
||||
// 3. Repeat the warp layout NumAtoms times along K mode to get wider vectorization
|
||||
auto mk_layout_mma_trgt = blocked_product(mk_layout_mma_warp, atom_layout);
|
||||
|
||||
// 4. Compose with a contiguous layout of values in each thread (required for smem vectorization)
|
||||
auto val_to_offset = logical_product(val_layout, size<1>(tv_layout_mma) / size(val_layout) * size(atom_layout));
|
||||
auto thr_to_offset = make_layout(size<0>(tv_layout_mma_warp));
|
||||
auto tv_to_offset = select<1,0>(logical_product(val_to_offset, thr_to_offset));
|
||||
auto layout_atom = composition(tv_to_offset, mk_layout_mma_trgt);
|
||||
|
||||
return layout_atom;
|
||||
}
|
||||
|
||||
template<class TileShape, class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst, class TiledCopy>
|
||||
__global__ void reorder_tensor_kernel(
|
||||
cute::Tensor<EngineSrc, LayoutSrc> S,
|
||||
cute::Tensor<EngineDst, LayoutDst> D,
|
||||
TiledCopy tiled_copy)
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
using T = typename EngineDst::value_type;
|
||||
|
||||
Tensor gS = local_tile(S, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z));
|
||||
Tensor gD = local_tile(D, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z));
|
||||
|
||||
auto thread_copy = tiled_copy.get_slice(threadIdx.x);
|
||||
Tensor tS = thread_copy.partition_S(gS);
|
||||
Tensor tD = thread_copy.partition_D(gD);
|
||||
|
||||
copy(tiled_copy, tS, tD);
|
||||
}
|
||||
|
||||
template<class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst>
|
||||
void reorder_tensor(
|
||||
cute::Tensor<EngineSrc, LayoutSrc> S,
|
||||
cute::Tensor<EngineDst, LayoutDst> D)
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
using T = typename EngineDst::value_type;
|
||||
static_assert(is_same_v<remove_const_t<typename EngineSrc::value_type>, T>, "Type mismatch");
|
||||
|
||||
// Construct a value layout that assigns at least 8 bits of contiguous elements in destination tensor to a thread
|
||||
// This avoids a race condition when writing out subbyte types (e.g. int4b_t).
|
||||
auto has_major_mode = [](auto s) {
|
||||
return any_of(s, [](auto a){ return is_constant<1, decltype(a)>{}; });
|
||||
};
|
||||
static_assert(has_major_mode(stride<0>(LayoutDst{})) ^ has_major_mode(stride<1>(LayoutDst{})),
|
||||
"Could not find stride-1 mode in destination layout");
|
||||
constexpr int N = shape_div(Int<8>{}, sizeof_bits<T>{});
|
||||
auto val_layout = conditional_return<has_major_mode(stride<0>(LayoutDst{}))>(
|
||||
make_layout(make_shape(Int<N>{}, Int<1>{}), GenColMajor{}),
|
||||
make_layout(make_shape(Int<1>{}, Int<N>{}), GenRowMajor{}));
|
||||
|
||||
// Make a tiled copy with a simple row-major thread order and above layout
|
||||
int constexpr NumThreads = 128;
|
||||
auto const thr_layout = make_layout(make_shape(Int<1>{}, Int<NumThreads>{}));
|
||||
auto tiled_copy = make_tiled_copy(Copy_Atom<DefaultCopy, T>{}, thr_layout, val_layout);
|
||||
|
||||
// Assign a group of 16 rows to a threadblock; this matches the shuffle atom size for Hopper
|
||||
using TileShape = Shape<_16>;
|
||||
auto tiled_D = group_modes<3,rank_v<LayoutDst>>(tiled_divide(D, TileShape{}));
|
||||
dim3 blocks{unsigned(size<1>(tiled_D)), 1u, unsigned(size<3>(tiled_D))};
|
||||
|
||||
reorder_tensor_kernel<TileShape><<<blocks, NumThreads>>>(S, D, tiled_copy);
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
// In-place version
|
||||
template<class T, class LayoutSrc, class LayoutDst>
|
||||
void reorder_tensor(
|
||||
T const* src,
|
||||
LayoutSrc const& layout_src,
|
||||
T * dst,
|
||||
LayoutDst const& layout_dst)
|
||||
{
|
||||
using namespace cute;
|
||||
reorder_tensor(make_tensor(make_gmem_ptr<T>(src), layout_src),
|
||||
make_tensor(make_gmem_ptr<T>(dst), layout_dst));
|
||||
}
|
||||
|
||||
// In-place version
|
||||
template<class T, class LayoutSrc, class LayoutDst>
|
||||
void reorder_tensor(
|
||||
T * data,
|
||||
LayoutSrc const& layout_src,
|
||||
LayoutDst const& layout_dst)
|
||||
{
|
||||
using namespace cute;
|
||||
cutlass::DeviceAllocation<T> temp(size(layout_src));
|
||||
reorder_tensor(data, layout_src, temp.get(), layout_dst);
|
||||
cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast<size_t>(size(layout_src)));
|
||||
}
|
||||
@ -513,12 +513,15 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major < 9) {
|
||||
if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -731,12 +731,15 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major < 9) {
|
||||
if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -768,16 +768,26 @@ int main(int argc, char const** argv) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4) ||
|
||||
(props.major != 8 && props.minor != 9)) {
|
||||
bool satisfied;
|
||||
if (props.major < 10) {
|
||||
// Pre-Blackwell
|
||||
satisfied = (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4);
|
||||
satisfied &= (props.major > 8) || (props.major == 8 && props.minor == 9);
|
||||
}
|
||||
else {
|
||||
satisfied = (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8);
|
||||
}
|
||||
|
||||
if (!satisfied) {
|
||||
//
|
||||
// This example requires an NVIDIA Ada-architecture GPU.
|
||||
// This example requires an NVIDIA GPU with compute capability 8.9 or greater.
|
||||
//
|
||||
|
||||
std::cout
|
||||
<< "CUTLASS's FP8 SM89 example requires a GPU of NVIDIA's Ada architecture "
|
||||
<< "and CUDA toolkit version 12.4 or later.\n";
|
||||
<< "CUTLASS's FP8 SM89 example requires an NVIDIA GPU with compute capability 8.9 or greater "
|
||||
<< "and CUDA toolkit version 12.4 or later"
|
||||
<< " (12.8 or later needed for SM100+)"
|
||||
<< std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -207,3 +207,35 @@ With this in mind, this example kernel has the following limitations:
|
||||
- This example kernel only supports dynamic image count, all other conv problem shape must be defined as `cute::Constant<>`s
|
||||
- Problem shapes (including dynamic image count `N`) must be evenly divisible by the tile shape
|
||||
- It does not perform fp32->tf32 numeric conversion, gmem inputs must be rounded to tf32 already
|
||||
|
||||
## Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
@ -37,8 +37,11 @@
|
||||
|
||||
Those assumptions are as:
|
||||
1. Fusion is over the N dimension.
|
||||
2. Top-K is either 2 or 4 elements, and the value is static (meaning two kernels have to be
|
||||
compiled to support both.)
|
||||
2. Top-K value is static (meaning multiple kernels have to be compiled to support
|
||||
different values.)
|
||||
* NOTE: Only K=2 and K=4 cases are performance-optimized and enabled by default.
|
||||
There is also a generic sort that supports all K values greater than 1, but it can lead to serious performance implications to the underlying kernel.
|
||||
If necessary, users can simply remove the K==2 || K ==4 assertion under cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp, and the generic sort will automatically be used for all other Ks.
|
||||
3. The GEMM tile shape along N is greater than or equal to problem size
|
||||
along N.
|
||||
|
||||
@ -501,12 +504,15 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major < 9) {
|
||||
if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -570,12 +570,15 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major < 9) {
|
||||
if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -469,12 +469,13 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major < 9) {
|
||||
if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -26,11 +26,13 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
include_directories(
|
||||
.
|
||||
)
|
||||
set(TEST_PREFETCH_CASE --m=8192 --n=64 --k=8192 --iterations=0)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
63_hopper_gemm_with_weight_prefetch
|
||||
63_hopper_gemm_with_weight_prefetch.cu
|
||||
)
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_PREFETCH_CASE
|
||||
)
|
||||
|
||||
target_include_directories(63_hopper_gemm_with_weight_prefetch PUBLIC .)
|
||||
|
||||
@ -74,9 +74,40 @@ echo "Overlap ratio of 0.8, prefetch ratio of 0.7"
|
||||
However, note that the example still runs a single GEMM, and most of the performance improvement
|
||||
is expected in end to end applications.
|
||||
|
||||
|
||||
## Limitations
|
||||
* The parameter defaults are typically not good choices, especially `prefetch_ratio`.
|
||||
When `prefetch_ratio` is unspecified (set to `-1.0`), the prefetch warp will `try_wait` on a
|
||||
memory barrier before issuing every single TMA load, and in many cases this will slow down
|
||||
prefetching to the point of being almost ineffective.
|
||||
|
||||
## Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
@ -362,11 +362,11 @@ public:
|
||||
using ClusterSyncWithPrefetchBarrier = typename cutlass::arch::NamedBarrier;
|
||||
auto prefetcher_arrive_barrier = ClusterSyncWithPrefetchBarrier(
|
||||
blockDim.x * blockDim.y * blockDim.z,
|
||||
/*reserved_named_barriers_*/ 14);
|
||||
/*id*/ 0);
|
||||
// Prefetcher warp doesn't arrive on this barrier.
|
||||
auto cluster_arrive_barrier = ClusterSyncWithPrefetchBarrier(
|
||||
blockDim.x * blockDim.y * blockDim.z - NumThreadsPerWarp,
|
||||
/*reserved_named_barriers_*/ 15);
|
||||
/*id*/ 1);
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) {
|
||||
__syncwarp();
|
||||
|
||||
@ -120,8 +120,7 @@
|
||||
#include "helper.h"
|
||||
|
||||
// Distributed GEMM helpers
|
||||
#include "util/benchmark.h"
|
||||
#include "util/device_copy.h"
|
||||
#include "dist_gemm_helpers.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@ -133,7 +132,8 @@ using namespace cute;
|
||||
using TP = _8;
|
||||
static constexpr int TP_ = TP{};
|
||||
|
||||
#if (defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && \
|
||||
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
|
||||
|
||||
// Distributed GEMM tiling/sharding schedule
|
||||
// Choices:
|
||||
@ -344,7 +344,8 @@ struct Result {
|
||||
|
||||
};
|
||||
|
||||
#if (defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && \
|
||||
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
@ -832,10 +833,10 @@ int main(int argc, char const **args) {
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major < 9) {
|
||||
if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater)." << std::endl;
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture "
|
||||
<< "(compute capability 90)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -62,3 +62,40 @@ procedure is the same, simply modify the following line in the example:
|
||||
```cpp
|
||||
using TP = _8;
|
||||
```
|
||||
|
||||
## References
|
||||
* [Distributed GEMM Blog](https://blog.shi-labs.com/distributed-gemm-88be6a481e2b)
|
||||
* [Distributed GEMM Talk on CUDA Mode](https://www.youtube.com/watch?v=NHRTCQBZokg)
|
||||
|
||||
## Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
|
||||
@ -17,6 +17,8 @@ Like all other CUTLASS examples, the NVIDIA driver, runtime, and CUDA Toolkit ar
|
||||
This example specifically requires CUDA Toolkit 12.6 or newer, due to some of the necessary
|
||||
CUDA graph APIs.
|
||||
|
||||
The minimum CUDA driver version for running this example is [560.28.03](https://docs.nvidia.com/cuda/archive/12.6.0/cuda-toolkit-release-notes/index.html#id5).
|
||||
|
||||
### Hardware / driver settings
|
||||
|
||||
This example requires Hopper GPUs with NVLink network.
|
||||
@ -84,3 +86,35 @@ GPU5 OK OK OK OK OK X OK OK
|
||||
GPU6 OK OK OK OK OK OK X OK
|
||||
GPU7 OK OK OK OK OK OK OK X
|
||||
```
|
||||
|
||||
## Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
@ -75,11 +75,11 @@
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
|
||||
// Includes from examples directory
|
||||
#include "helper.h"
|
||||
#include "hopper_fp8_commandline.hpp"
|
||||
#include "reference/host/gemm_with_blockwise_scaling.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@ -100,7 +100,7 @@ using LayoutB = cutlass::layout::ColumnMajor; // L
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C matrix configuration
|
||||
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
|
||||
using ElementC = float; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
@ -123,7 +123,13 @@ using ArchTag = cutlass::arch::Sm90; // T
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
||||
|
||||
using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_config(TileShape{}));
|
||||
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
|
||||
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
@ -143,8 +149,8 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui
|
||||
|
||||
using CollectiveMainloopWithBlockWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementA, cute::tuple<LayoutA, LayoutSFA>, AlignmentA,
|
||||
ElementB, cute::tuple<LayoutB, LayoutSFB>, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
@ -190,20 +196,22 @@ StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
StrideAux stride_aux;
|
||||
LayoutSFA layout_SFA;
|
||||
LayoutSFB layout_SFB;
|
||||
uint64_t seed;
|
||||
|
||||
using LayoutScalar = cutlass::layout::PackedVectorLayout;
|
||||
cutlass::HostTensor<ElementA , LayoutA > tensor_A;
|
||||
cutlass::HostTensor<ElementB , LayoutB > tensor_B;
|
||||
cutlass::HostTensor<ElementC , LayoutC > tensor_C;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_D;
|
||||
uint32_t mma_promotion_interval;
|
||||
cutlass::HostTensor<ElementBlockScale, LayoutA> blockscale_tensor_A;
|
||||
cutlass::HostTensor<ElementBlockScale, LayoutB> blockscale_tensor_B;
|
||||
cutlass::HostTensor<ElementBlockScale, LayoutScalar> blockscale_tensor_A;
|
||||
cutlass::HostTensor<ElementBlockScale, LayoutScalar> blockscale_tensor_B;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_ref_D;
|
||||
cutlass::HostTensor<ElementAux, LayoutAux> tensor_aux;
|
||||
cutlass::HostTensor<ElementAux, LayoutAux> tensor_ref_aux;
|
||||
|
||||
using LayoutScalar = cutlass::layout::PackedVectorLayout;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_alpha;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_beta;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_A;
|
||||
@ -251,117 +259,116 @@ struct Result
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
return true;
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, bits_input);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Helper to initialize a block of device data (scale_tensors)
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_scale_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_scale_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
double scope_max, scope_min;
|
||||
|
||||
scope_min = -1;
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
scope_max = 1;
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
return true;
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options<RasterOrderOptions> &options) {
|
||||
|
||||
// Find Block Scaling tensor shapes based on problem shape and TileShape
|
||||
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
|
||||
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{})));
|
||||
auto blockscale_m = cute::get<0>(blockscale_shape);
|
||||
auto blockscale_n = cute::get<1>(blockscale_shape);
|
||||
auto blockscale_k = cute::get<2>(blockscale_shape);
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
|
||||
stride_aux = stride_D;
|
||||
|
||||
// Layout SFA and SFB represent logically broadcasting data in CuTe.
|
||||
// E.g., if Layout SFA has shape ((ScaleGranularityM, M / ScaleGranularityM), (ScaleGraunularityK, K / ScaleGranularityK))
|
||||
// and strides ((0, 1), (0, M / ScaleGraunuarlityM)), then each collection of ScaleGranularityM x ScaleGranularityK
|
||||
// indecies in the tensor map to the same offset.
|
||||
|
||||
layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l));
|
||||
layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l));
|
||||
|
||||
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
|
||||
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
|
||||
auto blockscale_a_coord = cutlass::make_Coord(blockscale_m * options.l, blockscale_k);
|
||||
auto blockscale_b_coord = cutlass::make_Coord(blockscale_k, blockscale_n * options.l);
|
||||
auto blockscale_a_coord = cutlass::make_Coord(size(filter_zeros(layout_SFA)));
|
||||
auto blockscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB)));
|
||||
|
||||
tensor_A.resize(a_coord);
|
||||
blockscale_tensor_A.resize(blockscale_a_coord);
|
||||
@ -398,6 +405,10 @@ void initialize(const Options<RasterOrderOptions> &options) {
|
||||
blockscale_tensor_A.sync_device();
|
||||
blockscale_tensor_B.sync_device();
|
||||
|
||||
// Note : This value has to match the KernelSchedule::ScalePromotionInterval
|
||||
// Else kernel will fail can_implement() check
|
||||
// Deprecation Notice : We plan to remove this params member in an upcoming release
|
||||
// Users can safely delete this line from their code, since the default is already 4
|
||||
mma_promotion_interval = 4;
|
||||
|
||||
if (options.save_aux) {
|
||||
@ -434,14 +445,18 @@ void initialize(const Options<RasterOrderOptions> &options) {
|
||||
|
||||
if (IsDFp8 && options.save_amax) {
|
||||
abs_max_D.resize(cutlass::make_Coord(1));
|
||||
initialize_tensor(abs_max_D.host_view(), cutlass::Distribution::AllZeros, 0);
|
||||
abs_max_D.sync_device();
|
||||
reference_abs_max_D.resize(cutlass::make_Coord(1));
|
||||
initialize_tensor(reference_abs_max_D.host_view(), cutlass::Distribution::AllZeros, 0);
|
||||
}
|
||||
|
||||
if (IsAuxFp8 && options.save_aux && options.save_amax) {
|
||||
abs_max_aux.resize(cutlass::make_Coord(1));
|
||||
initialize_tensor(abs_max_aux.host_view(), cutlass::Distribution::AllZeros, 0);
|
||||
abs_max_aux.sync_device();
|
||||
reference_abs_max_aux.resize(cutlass::make_Coord(1));
|
||||
initialize_tensor(reference_abs_max_aux.host_view(), cutlass::Distribution::AllZeros, 0);
|
||||
}
|
||||
}
|
||||
|
||||
@ -457,7 +472,9 @@ typename Gemm::Arguments args_from_options(const Options<RasterOrderOptions> &op
|
||||
stride_B,
|
||||
mma_promotion_interval,
|
||||
blockscale_tensor_A.device_data(),
|
||||
blockscale_tensor_B.device_data()
|
||||
layout_SFA,
|
||||
blockscale_tensor_B.device_data(),
|
||||
layout_SFB
|
||||
},
|
||||
{
|
||||
{}, // epilogue.thread
|
||||
@ -511,13 +528,6 @@ bool verify(const Options<RasterOrderOptions> &options) {
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Block scaling tensors shapes based CTA Block (TileShape) and GEMM Problem shape
|
||||
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
|
||||
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{})));
|
||||
auto blockscale_m = cute::get<0>(blockscale_shape);
|
||||
auto blockscale_n = cute::get<1>(blockscale_shape);
|
||||
auto blockscale_k = cute::get<2>(blockscale_shape);
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
auto A = cute::make_tensor(tensor_A.host_data(),
|
||||
cute::make_layout(
|
||||
@ -550,28 +560,18 @@ bool verify(const Options<RasterOrderOptions> &options) {
|
||||
)
|
||||
);
|
||||
|
||||
auto blockscale_A = cute::make_tensor(blockscale_tensor_A.host_data(),
|
||||
cute::make_layout(
|
||||
cute::make_shape(blockscale_m, blockscale_k, options.l),
|
||||
cute::make_stride(blockscale_k, 1, blockscale_m * blockscale_k)
|
||||
)
|
||||
);
|
||||
auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(),
|
||||
cute::make_layout(
|
||||
cute::make_shape(blockscale_n, blockscale_k, options.l),
|
||||
cute::make_stride(blockscale_k, 1, blockscale_n * blockscale_k)
|
||||
)
|
||||
);
|
||||
auto SFA = cute::make_tensor(blockscale_tensor_A.host_data(), layout_SFA);
|
||||
auto SFB = cute::make_tensor(blockscale_tensor_B.host_data(), layout_SFB);
|
||||
|
||||
using unused_t = decltype(D);
|
||||
|
||||
cutlass::reference::host::GettMainloopParams<ElementAccumulator,
|
||||
decltype(A), decltype(B),
|
||||
decltype(blockscale_A), decltype(blockscale_B),
|
||||
TileShape> mainloop_params{
|
||||
A, B, // Operand Tensors
|
||||
blockscale_A, blockscale_B // Blockwise scaling Tensors
|
||||
};
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator,
|
||||
decltype(A),
|
||||
decltype(SFA),
|
||||
decltype(B),
|
||||
decltype(SFB)
|
||||
> mainloop_params{A, SFA, B, SFB};
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
ElementScalar,
|
||||
@ -604,29 +604,40 @@ bool verify(const Options<RasterOrderOptions> &options) {
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// compare_reference
|
||||
bool passed = true;
|
||||
tensor_D.sync_host();
|
||||
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
|
||||
passed &= cutlass::reference::host::TensorRelativelyEquals(tensor_D.host_view(), tensor_ref_D.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor));
|
||||
double mse = cutlass::reference::host::TensorMSE(tensor_D.host_view(), tensor_ref_D.host_view());
|
||||
double mre = cutlass::reference::host::TensorMRE(tensor_D.host_view(), tensor_ref_D.host_view());
|
||||
double max_error = cutlass::reference::host::TensorGreatestError(tensor_D.host_view(), tensor_ref_D.host_view());
|
||||
std::cout << " Result MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl;
|
||||
|
||||
if (false) {
|
||||
std::cout << "tensor_ref_D.host_view() {" << std::endl
|
||||
<< tensor_ref_D.host_view() << std::endl
|
||||
<< "}" << std::endl;
|
||||
std::cout << "tensor_D.host_view() {" << std::endl
|
||||
<< tensor_D.host_view() << std::endl
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
#if 0
|
||||
std::cout << "tensor_ref_D.host_view() {" << std::endl
|
||||
<< tensor_ref_D.host_view() << std::endl
|
||||
<< "}" << std::endl;
|
||||
std::cout << "tensor_D.host_view() {" << std::endl
|
||||
<< tensor_D.host_view() << std::endl
|
||||
<< "}" << std::endl;
|
||||
#endif
|
||||
|
||||
if (IsDFp8 && options.save_amax) {
|
||||
abs_max_D.sync_host();
|
||||
passed &= abs_max_D.at(cutlass::make_Coord(0)) == reference_abs_max_D.at(cutlass::make_Coord(0));
|
||||
std::cout << " Abs max D: " << abs_max_D.at(cutlass::make_Coord(0)) << ", reference: " << reference_abs_max_D.at(cutlass::make_Coord(0)) << std::endl;
|
||||
passed &= cutlass::relatively_equal(abs_max_D.at(cutlass::make_Coord(0)), reference_abs_max_D.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor));
|
||||
}
|
||||
|
||||
if (options.save_aux) {
|
||||
tensor_aux.sync_host();
|
||||
passed &= cutlass::reference::host::TensorEquals(tensor_ref_aux.host_view(), tensor_aux.host_view());
|
||||
passed &= cutlass::reference::host::TensorRelativelyEquals(tensor_aux.host_view(), tensor_ref_aux.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor));
|
||||
mse = cutlass::reference::host::TensorMSE(tensor_aux.host_view(), tensor_ref_aux.host_view());
|
||||
mre = cutlass::reference::host::TensorMRE(tensor_aux.host_view(), tensor_ref_aux.host_view());
|
||||
max_error = cutlass::reference::host::TensorGreatestError(tensor_aux.host_view(), tensor_ref_aux.host_view());
|
||||
std::cout << " Aux MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl;
|
||||
if (IsAuxFp8 && options.save_amax) {
|
||||
abs_max_aux.sync_host();
|
||||
passed &= abs_max_aux.at(cutlass::make_Coord(0)) == reference_abs_max_aux.at(cutlass::make_Coord(0));
|
||||
std::cout << " Abs max aux: " << abs_max_aux.at(cutlass::make_Coord(0)) << ", reference: " << reference_abs_max_aux.at(cutlass::make_Coord(0)) << std::endl;
|
||||
passed &= cutlass::relatively_equal(abs_max_aux.at(cutlass::make_Coord(0)), reference_abs_max_aux.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor));
|
||||
}
|
||||
}
|
||||
|
||||
@ -662,20 +673,22 @@ int run(Options<RasterOrderOptions> &options)
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
if (options.verify) {
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
// if (!result.passed) {
|
||||
// exit(-1);
|
||||
// }
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
}
|
||||
else {
|
||||
result.passed = true;
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
for (int iter = 0; iter < options.warmup + options.iterations; ++iter) {
|
||||
if (iter == options.warmup)
|
||||
timer.start();
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
@ -700,7 +713,7 @@ int run(Options<RasterOrderOptions> &options)
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
return result.passed;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
@ -746,7 +759,9 @@ int main(int argc, char const **args) {
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
bool passed = run<Gemm>(options);
|
||||
if (!passed)
|
||||
return -1;
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
|
||||
@ -0,0 +1,806 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Grouped scale Hopper FP8 GEMM example using CUTLASS 3.0 APIs for NVIDIA Hopper architecture
|
||||
This example demonstrate a grouped scaled FP8 GEMM using the new CUTLASS 3.0.
|
||||
APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows:
|
||||
1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA)
|
||||
which are more efficient than the Ampere tensor core instructions.
|
||||
2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large
|
||||
blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous
|
||||
copies between thread blocks in a cluster.
|
||||
3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details).
|
||||
4. This example shows all important fusions used by FP8 gemm kernels, i.e., grouped scale factor along M for
|
||||
A, blocked scale factor along K for A tensor, blocked scale factor for B tensor, the abs_max value of D tensor.
|
||||
5. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the
|
||||
CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can
|
||||
improve performance.
|
||||
Examples:
|
||||
$ ./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling \
|
||||
--m=2816 --n=3072 --k=16384 \
|
||||
--save_aux=false --save_amax=false \
|
||||
--device_scale=false --raster=h --swizzle=2
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
|
||||
// Includes from examples directory
|
||||
#include "helper.h"
|
||||
#include "hopper_fp8_commandline.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C matrix configuration
|
||||
using ElementC = float; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// D matrix configuration
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
// Auxiliary matrix configuration and other fusion types
|
||||
using ElementAux = ElementC;
|
||||
using LayoutAux = LayoutC;
|
||||
using ElementAmax = float;
|
||||
using ElementBias = float;
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementBlockScale = float; // Element type for blockscaling during accumulation
|
||||
using ElementCompute = float; // Element type for epilogue computation
|
||||
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
constexpr int ScaleGranularityM = 1;
|
||||
constexpr int ScaleGranularityN = 128;
|
||||
constexpr int ScaleGranularityK = 128;
|
||||
|
||||
constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
|
||||
constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
|
||||
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
|
||||
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
|
||||
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux<
|
||||
LayoutAux, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementC>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
TileShape, ClusterShape,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD,
|
||||
EpilogueSchedule,
|
||||
FusionOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, cute::tuple<LayoutA, LayoutSFA>, AlignmentA,
|
||||
ElementB, cute::tuple<LayoutB, LayoutSFB>, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
|
||||
>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
cutlass::gemm::StreamKScheduler
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Extract information from Gemm kernel.
|
||||
using EpilogueOutputOp = typename Gemm::EpilogueOutputOp;
|
||||
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
|
||||
using ElementAmax = typename EpilogueOutputOp::ElementAmax;
|
||||
using ActivationFunctor = typename EpilogueOutputOp::ActivationFn;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using StrideAux = StrideD;
|
||||
|
||||
constexpr bool IsDFp8 =
|
||||
cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
|
||||
cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
|
||||
|
||||
constexpr bool IsAuxFp8 =
|
||||
cute::is_same_v<ElementAux, cutlass::float_e4m3_t> or
|
||||
cute::is_same_v<ElementAux, cutlass::float_e5m2_t>;
|
||||
|
||||
static_assert(cute::is_same_v<ElementAccumulator, ElementBlockScale>,
|
||||
"ElementAccumulator and ElementBlockScale should be same datatype");
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
StrideAux stride_aux;
|
||||
LayoutSFA layout_SFA;
|
||||
LayoutSFB layout_SFB;
|
||||
uint64_t seed;
|
||||
|
||||
using LayoutScalar = cutlass::layout::PackedVectorLayout;
|
||||
|
||||
cutlass::HostTensor<ElementA , LayoutA > tensor_A;
|
||||
cutlass::HostTensor<ElementB , LayoutB > tensor_B;
|
||||
cutlass::HostTensor<ElementC , LayoutC > tensor_C;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_D;
|
||||
uint32_t mma_promotion_interval;
|
||||
cutlass::HostTensor<ElementBlockScale, LayoutScalar> blockscale_tensor_A;
|
||||
cutlass::HostTensor<ElementBlockScale, LayoutScalar> blockscale_tensor_B;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_ref_D;
|
||||
cutlass::HostTensor<ElementAux, LayoutAux> tensor_aux;
|
||||
cutlass::HostTensor<ElementAux, LayoutAux> tensor_ref_aux;
|
||||
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_alpha;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_beta;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_A;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_B;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_C;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_D;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_aux;
|
||||
cutlass::HostTensor<ElementAmax , LayoutScalar> abs_max_D;
|
||||
cutlass::HostTensor<ElementAmax , LayoutScalar> reference_abs_max_D;
|
||||
cutlass::HostTensor<ElementAmax , LayoutScalar> abs_max_aux;
|
||||
cutlass::HostTensor<ElementAmax , LayoutScalar> reference_abs_max_aux;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions;
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, bits_input);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Helper to initialize a block of device data (scale_tensors)
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_scale_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
|
||||
scope_min = -1;
|
||||
scope_max = 1;
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Not implementated.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options<RasterOrderOptions> &options) {
|
||||
|
||||
assert(options.m % ScaleGranularityM == 0);
|
||||
assert(options.n % ScaleGranularityN == 0);
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
|
||||
stride_aux = stride_D;
|
||||
layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l));
|
||||
layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l));
|
||||
|
||||
|
||||
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
|
||||
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
|
||||
auto groupscale_a_coord = cutlass::make_Coord(size(filter_zeros(layout_SFA)));
|
||||
auto groupscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB)));
|
||||
|
||||
tensor_A.resize(a_coord);
|
||||
tensor_B.resize(b_coord);
|
||||
blockscale_tensor_A.resize(groupscale_a_coord);
|
||||
blockscale_tensor_B.resize(groupscale_b_coord);
|
||||
tensor_C.resize(c_coord);
|
||||
tensor_D.resize(c_coord);
|
||||
tensor_ref_D.resize(c_coord);
|
||||
|
||||
cutlass::Distribution::Kind dist_A = cutlass::Distribution::Uniform;
|
||||
cutlass::Distribution::Kind dist_B = cutlass::Distribution::Uniform;
|
||||
cutlass::Distribution::Kind dist_C = cutlass::Distribution::Uniform;
|
||||
cutlass::Distribution::Kind dist_scaleA = cutlass::Distribution::Uniform;
|
||||
cutlass::Distribution::Kind dist_scaleB = cutlass::Distribution::Uniform;
|
||||
|
||||
initialize_tensor(tensor_A.host_view(), dist_A, seed + 2022);
|
||||
initialize_tensor(tensor_B.host_view(), dist_B, seed + 2023);
|
||||
initialize_tensor(tensor_C.host_view(), dist_C, seed + 2024);
|
||||
initialize_scale_tensor(blockscale_tensor_A.host_view(), dist_scaleA, seed + 2025);
|
||||
initialize_scale_tensor(blockscale_tensor_B.host_view(), dist_scaleB, seed + 2026);
|
||||
|
||||
#if 0 // Dump blockscaled tensors
|
||||
std::cout << "blockscale_tensor_A: " << groupscale_a_coord << std::endl;
|
||||
std::cout << blockscale_tensor_A.host_view() << "\n";
|
||||
std::cout << "blockscale_tensor_B: " << groupscale_b_coord << std::endl;
|
||||
std::cout << blockscale_tensor_B.host_view() << "\n";
|
||||
#endif
|
||||
|
||||
// Print group scaling tensors on the host side.
|
||||
tensor_A.sync_device();
|
||||
tensor_B.sync_device();
|
||||
tensor_C.sync_device();
|
||||
tensor_D.sync_device();
|
||||
blockscale_tensor_A.sync_device();
|
||||
blockscale_tensor_B.sync_device();
|
||||
|
||||
// Note : This value has to match the KernelSchedule::ScalePromotionInterval
|
||||
// Else kernel will fail can_implement() check
|
||||
// Deprecation Notice : We plan to remove this params member in an upcoming release
|
||||
// Users can safely delete this line from their code, since the default is already 4
|
||||
mma_promotion_interval = 4;
|
||||
|
||||
if (options.save_aux) {
|
||||
tensor_aux.resize(c_coord);
|
||||
tensor_aux.sync_device();
|
||||
tensor_ref_aux.resize(c_coord);
|
||||
}
|
||||
|
||||
if (options.device_scale) {
|
||||
scalar_alpha.resize(cutlass::make_Coord(1));
|
||||
scalar_beta.resize(cutlass::make_Coord(1));
|
||||
scale_A.resize(cutlass::make_Coord(1));
|
||||
scale_B.resize(cutlass::make_Coord(1));
|
||||
scale_C.resize(cutlass::make_Coord(1));
|
||||
scale_D.resize(cutlass::make_Coord(1));
|
||||
scale_aux.resize(cutlass::make_Coord(1));
|
||||
|
||||
cutlass::reference::host::TensorFill(scalar_alpha.host_view(), options.alpha);
|
||||
cutlass::reference::host::TensorFill(scalar_beta.host_view(), options.beta);
|
||||
cutlass::reference::host::TensorFill(scale_A.host_view(), options.scale_a);
|
||||
cutlass::reference::host::TensorFill(scale_B.host_view(), options.scale_b);
|
||||
cutlass::reference::host::TensorFill(scale_C.host_view(), options.scale_c);
|
||||
cutlass::reference::host::TensorFill(scale_D.host_view(), options.scale_d);
|
||||
cutlass::reference::host::TensorFill(scale_aux.host_view(), options.scale_aux);
|
||||
|
||||
scalar_alpha.sync_device();
|
||||
scalar_beta.sync_device();
|
||||
scale_A.sync_device();
|
||||
scale_B.sync_device();
|
||||
scale_C.sync_device();
|
||||
scale_D.sync_device();
|
||||
scale_aux.sync_device();
|
||||
}
|
||||
|
||||
if (IsDFp8 && options.save_amax) {
|
||||
abs_max_D.resize(cutlass::make_Coord(1));
|
||||
initialize_tensor(abs_max_D.host_view(), cutlass::Distribution::AllZeros, 0);
|
||||
abs_max_D.sync_device();
|
||||
reference_abs_max_D.resize(cutlass::make_Coord(1));
|
||||
initialize_tensor(reference_abs_max_D.host_view(), cutlass::Distribution::AllZeros, 0);
|
||||
}
|
||||
|
||||
if (IsAuxFp8 && options.save_aux && options.save_amax) {
|
||||
abs_max_aux.resize(cutlass::make_Coord(1));
|
||||
initialize_tensor(abs_max_aux.host_view(), cutlass::Distribution::AllZeros, 0);
|
||||
abs_max_aux.sync_device();
|
||||
reference_abs_max_aux.resize(cutlass::make_Coord(1));
|
||||
initialize_tensor(reference_abs_max_aux.host_view(), cutlass::Distribution::AllZeros, 0);
|
||||
}
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
template<typename GemmArguments>
|
||||
GemmArguments args_from_options(const Options<RasterOrderOptions> &options)
|
||||
{
|
||||
GemmArguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, options.l},
|
||||
{tensor_A.device_data(),
|
||||
stride_A,
|
||||
tensor_B.device_data(),
|
||||
stride_B,
|
||||
mma_promotion_interval,
|
||||
blockscale_tensor_A.device_data(),
|
||||
layout_SFA,
|
||||
blockscale_tensor_B.device_data(),
|
||||
layout_SFB
|
||||
},
|
||||
{
|
||||
{}, // epilogue.thread
|
||||
tensor_C.device_data(), stride_C,
|
||||
tensor_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
|
||||
auto &fusion_args = arguments.epilogue.thread;
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.beta = options.beta;
|
||||
fusion_args.alpha_ptr = scalar_alpha.device_data();
|
||||
fusion_args.beta_ptr = scalar_beta.device_data();
|
||||
fusion_args.scale_a = options.scale_a;
|
||||
fusion_args.scale_b = options.scale_b;
|
||||
fusion_args.scale_c = options.scale_c;
|
||||
fusion_args.scale_a_ptr = scale_A.device_data();
|
||||
fusion_args.scale_b_ptr = scale_B.device_data();
|
||||
fusion_args.scale_c_ptr = scale_C.device_data();
|
||||
|
||||
// ignored if tensor types are not fp8
|
||||
fusion_args.scale_d = options.scale_d;
|
||||
fusion_args.scale_aux = options.scale_aux;
|
||||
fusion_args.scale_d_ptr = scale_D.device_data();
|
||||
fusion_args.scale_aux_ptr = scale_aux.device_data();
|
||||
|
||||
// leaving/setting these as nullptr disables the fusion at runtime
|
||||
fusion_args.bias_ptr = nullptr;
|
||||
|
||||
if (options.save_aux) {
|
||||
fusion_args.aux_ptr = tensor_aux.device_data();
|
||||
fusion_args.dAux = stride_aux;
|
||||
if (options.save_amax) {
|
||||
fusion_args.amax_aux_ptr = abs_max_aux.device_data();
|
||||
}
|
||||
}
|
||||
|
||||
if (options.save_amax) {
|
||||
fusion_args.amax_D_ptr = abs_max_D.device_data();
|
||||
}
|
||||
|
||||
arguments.scheduler.raster_order = options.raster;
|
||||
// The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8)
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
/// Don't know why the compiler does not like verify() being templated...
|
||||
bool verify(const Options<RasterOrderOptions> &options) {
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
auto A = cute::make_tensor(tensor_A.host_data(),
|
||||
cute::make_layout(
|
||||
cute::make_shape(options.m, options.k, options.l),
|
||||
stride_A
|
||||
)
|
||||
);
|
||||
auto B = cute::make_tensor(tensor_B.host_data(),
|
||||
cute::make_layout(
|
||||
cute::make_shape(options.n, options.k, options.l),
|
||||
stride_B
|
||||
)
|
||||
);
|
||||
auto C = cute::make_tensor(tensor_C.host_data(),
|
||||
cute::make_layout(
|
||||
cute::make_shape(options.m, options.n, options.l),
|
||||
stride_C
|
||||
)
|
||||
);
|
||||
auto D = cute::make_tensor(tensor_ref_D.host_data(),
|
||||
cute::make_layout(
|
||||
cute::make_shape(options.m, options.n, options.l),
|
||||
stride_D
|
||||
)
|
||||
);
|
||||
auto Aux = cute::make_tensor(tensor_ref_aux.host_data(),
|
||||
cute::make_layout(
|
||||
cute::make_shape(options.m, options.n, options.l),
|
||||
stride_aux
|
||||
)
|
||||
);
|
||||
|
||||
auto SFA = cute::make_tensor(blockscale_tensor_A.host_data(), layout_SFA);
|
||||
auto SFB = cute::make_tensor(blockscale_tensor_B.host_data(), layout_SFB);
|
||||
|
||||
using unused_t = decltype(D);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator,
|
||||
decltype(A),
|
||||
decltype(SFA),
|
||||
decltype(B),
|
||||
decltype(SFB)
|
||||
> mainloop_params{A, SFA, B, SFB};
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
ElementScalar,
|
||||
ElementScalar,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
decltype(C),
|
||||
decltype(D),
|
||||
unused_t, // bias
|
||||
decltype(Aux),
|
||||
unused_t, // valpha
|
||||
unused_t, // vbeta
|
||||
ActivationFunctor
|
||||
> epilogue_params;
|
||||
|
||||
epilogue_params.C = C;
|
||||
epilogue_params.D = D;
|
||||
epilogue_params.Aux = Aux;
|
||||
epilogue_params.alpha = options.alpha;
|
||||
epilogue_params.beta = options.beta;
|
||||
epilogue_params.scale_a = options.scale_a;
|
||||
epilogue_params.scale_b = options.scale_b;
|
||||
epilogue_params.scale_c = options.scale_c;
|
||||
epilogue_params.scale_d = options.scale_d;
|
||||
epilogue_params.scale_aux = options.scale_aux;
|
||||
epilogue_params.abs_max_D = reference_abs_max_D.host_data();
|
||||
epilogue_params.abs_max_Aux = reference_abs_max_aux.host_data();
|
||||
|
||||
// get reference result
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// compare_reference
|
||||
bool passed = true;
|
||||
tensor_D.sync_host();
|
||||
passed &= cutlass::reference::host::TensorRelativelyEquals(tensor_D.host_view(), tensor_ref_D.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor));
|
||||
double mse = cutlass::reference::host::TensorMSE(tensor_D.host_view(), tensor_ref_D.host_view());
|
||||
double mre = cutlass::reference::host::TensorMRE(tensor_D.host_view(), tensor_ref_D.host_view());
|
||||
double max_error = cutlass::reference::host::TensorGreatestError(tensor_D.host_view(), tensor_ref_D.host_view());
|
||||
std::cout << " Result MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl;
|
||||
|
||||
#if 0
|
||||
std::cout << "tensor_ref_D.host_view() {" << std::endl
|
||||
<< tensor_ref_D.host_view() << std::endl
|
||||
<< "}" << std::endl;
|
||||
std::cout << "tensor_D.host_view() {" << std::endl
|
||||
<< tensor_D.host_view() << std::endl
|
||||
<< "}" << std::endl;
|
||||
#endif
|
||||
|
||||
if (IsDFp8 && options.save_amax) {
|
||||
abs_max_D.sync_host();
|
||||
std::cout << " Abs max D: " << abs_max_D.at(cutlass::make_Coord(0)) << ", reference: " << reference_abs_max_D.at(cutlass::make_Coord(0)) << std::endl;
|
||||
passed &= cutlass::relatively_equal(abs_max_D.at(cutlass::make_Coord(0)), reference_abs_max_D.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor));
|
||||
}
|
||||
|
||||
if (options.save_aux) {
|
||||
tensor_aux.sync_host();
|
||||
passed &= cutlass::reference::host::TensorRelativelyEquals(tensor_aux.host_view(), tensor_ref_aux.host_view(), ElementAux(options.epsilon), ElementAux(options.non_zero_floor));
|
||||
mse = cutlass::reference::host::TensorMSE(tensor_aux.host_view(), tensor_ref_aux.host_view());
|
||||
mre = cutlass::reference::host::TensorMRE(tensor_aux.host_view(), tensor_ref_aux.host_view());
|
||||
max_error = cutlass::reference::host::TensorGreatestError(tensor_aux.host_view(), tensor_ref_aux.host_view());
|
||||
std::cout << " Aux MSE: " << mse << ", MRE: " << mre << ", greatest error: " << max_error << std::endl;
|
||||
if (IsAuxFp8 && options.save_amax) {
|
||||
abs_max_aux.sync_host();
|
||||
std::cout << " Abs max aux: " << abs_max_aux.at(cutlass::make_Coord(0)) << ", reference: " << reference_abs_max_aux.at(cutlass::make_Coord(0)) << std::endl;
|
||||
passed &= cutlass::relatively_equal(abs_max_aux.at(cutlass::make_Coord(0)), reference_abs_max_aux.at(cutlass::make_Coord(0)), ElementScalar(options.epsilon), ElementScalar(options.non_zero_floor));
|
||||
}
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
int run(Options<RasterOrderOptions> &options) {
|
||||
|
||||
bool skip = false;
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
|
||||
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
|
||||
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
|
||||
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
|
||||
|
||||
|
||||
if (options.m < ScaleGranularityM) {
|
||||
std::cout << " Skippig (m size: " << options.m << " less than ScaleGranularityM: " << ScaleGranularityM << "):" << std::endl;
|
||||
skip = true;
|
||||
}
|
||||
|
||||
if (options.n < ScaleGranularityN) {
|
||||
std::cout << " Skippig (n size: " << options.n << " less than ScaleGranularityN: " << ScaleGranularityN << "):" << std::endl;
|
||||
skip = true;
|
||||
}
|
||||
|
||||
if (options.k < size<2>(TileShape{})) {
|
||||
std::cout << " Skippig (k size: " << options.k << " less than TileShape[2]: " << size<2>(TileShape{}) << "):" << std::endl;
|
||||
skip = true;
|
||||
}
|
||||
|
||||
if (!skip) std::cout << " Running... " << std::endl;
|
||||
else return -1;
|
||||
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options<typename Gemm::Arguments>(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
if (options.verify) {
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
}
|
||||
else {
|
||||
result.passed = true;
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
for (int iter = 0; iter < options.warmup + options.iterations; ++iter) {
|
||||
if (iter == options.warmup)
|
||||
timer.start();
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::string raster = "Heuristic";
|
||||
|
||||
if (options.raster == RasterOrderOptions::AlongN) {
|
||||
raster = "Along N";
|
||||
}
|
||||
else if (options.raster == RasterOrderOptions::AlongM) {
|
||||
raster = "Along M";
|
||||
}
|
||||
|
||||
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
return result.passed;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12) {
|
||||
std::cerr << "This example requires CUDA 12 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major != 9) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options<RasterOrderOptions> options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
bool passed = true;
|
||||
passed = run(options);
|
||||
if (!passed)
|
||||
return -1;
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -30,3 +30,8 @@ cutlass_example_add_executable(
|
||||
67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling
|
||||
67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling
|
||||
67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu
|
||||
)
|
||||
|
||||
@ -34,6 +34,7 @@ template<typename RasterOrderOptions>
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool verify = true;
|
||||
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
float scale_a = 1.f, scale_b = 1.f, scale_c = 1.f, scale_d = 1.f, scale_aux = 1.f;
|
||||
@ -41,9 +42,12 @@ struct Options {
|
||||
bool save_aux = true;
|
||||
bool save_amax = true;
|
||||
int iterations = 1000;
|
||||
int warmup = 1000;
|
||||
int m = 1024, n = 512, k = 1024, l = 1;
|
||||
RasterOrderOptions raster;
|
||||
int swizzle;
|
||||
float epsilon = 0.02f;
|
||||
float non_zero_floor = 1.f;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
@ -68,7 +72,11 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("device_scale", device_scale, false);
|
||||
cmd.get_cmd_line_argument("save_aux", save_aux, true);
|
||||
cmd.get_cmd_line_argument("save_amax", save_amax, true);
|
||||
cmd.get_cmd_line_argument("warmup", warmup);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("verify", verify);
|
||||
cmd.get_cmd_line_argument("epsilon", epsilon);
|
||||
cmd.get_cmd_line_argument("non-zero-floor", non_zero_floor);
|
||||
|
||||
char raster_char;
|
||||
cmd.get_cmd_line_argument("raster", raster_char);
|
||||
@ -89,8 +97,8 @@ struct Options {
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "54_fp8_hopper_warp_specialized_gemm\n\n"
|
||||
<< " Hopper FP8 GEMM using a Warp Specialized kernel.\n\n"
|
||||
out << "67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling\n\n"
|
||||
<< " Hopper FP8 GEMM using a Warp Specialized kernel with Blockwise Scaling.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
@ -109,11 +117,14 @@ struct Options {
|
||||
<< " --save_amax=<bool> Save the pre-scaled max absolute value of any fp8 outputs (aux and/or D) (default: true)\n"
|
||||
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n"
|
||||
<< " --swizzle=<int> CTA Rasterization swizzle\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
|
||||
<< " --verify=<bool> Verify the results.\n\n"
|
||||
<< " --epsilon=<float> The epsilon value for comparing the results.\n\n"
|
||||
<< " --non-zero-floor=<float> The none zero floor for comparing the results.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "54_fp8_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
<< "$ " << "67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -1,504 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Reference implementation for GETT in host-side code.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/relatively_equal.h"
|
||||
#include <iostream>
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::reference::host {
|
||||
|
||||
template<class T, class = void>
|
||||
struct ElementTraits {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template<class T>
|
||||
struct ElementTraits<T, std::enable_if_t<!std::is_same_v<decltype(std::declval<T>().get()), void> > > {
|
||||
using type = decltype(std::declval<T>().get());
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ElementAccumulator_,
|
||||
class TensorA_, // (M, K, L)
|
||||
class TensorB_, // (N, K, L)
|
||||
class TensorScaleA_, // (m, k, L)
|
||||
class TensorScaleB_, // (n, k, L)
|
||||
class TileShape_
|
||||
>
|
||||
struct GettMainloopParams {
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using TensorA = TensorA_;
|
||||
using TensorB = TensorB_;
|
||||
using EngineA = typename TensorA::engine_type;
|
||||
using LayoutA = typename TensorA::layout_type;
|
||||
using EngineB = typename TensorB::engine_type;
|
||||
using LayoutB = typename TensorB::layout_type;
|
||||
|
||||
using TensorScaleA = TensorScaleA_;
|
||||
using TensorScaleB = TensorScaleB_;
|
||||
using TileShape = TileShape_;
|
||||
using EngineScaleA = typename TensorScaleA::engine_type;
|
||||
using EngineScaleB = typename TensorScaleB::engine_type;
|
||||
|
||||
TensorA A{};
|
||||
TensorB B{};
|
||||
TensorScaleA ScaleA{};
|
||||
TensorScaleB ScaleB{};
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template<
|
||||
class ElementScalar_,
|
||||
class ElementScalingFactor_,
|
||||
class ElementAccumulator_,
|
||||
class ElementCompute_,
|
||||
class TensorC_, // (M, N, L)
|
||||
class TensorD_, // (M, N, L)
|
||||
class VectorBias_ = TensorD_, // (M, 1)
|
||||
class TensorAux_ = TensorD_, // (M, N, L)
|
||||
class VectorAlpha_ = TensorD_, // (M, 1)
|
||||
class VectorBeta_ = VectorAlpha_, // (M, 1)
|
||||
class ActivationFunctor_ = cutlass::epilogue::thread::Identity<ElementCompute_>,
|
||||
class BiasBinaryOp_ = cutlass::plus<ElementCompute_>,
|
||||
bool PerColumnBias_ = false
|
||||
>
|
||||
struct GettEpilogueParams {
|
||||
using ElementScalar = ElementScalar_;
|
||||
using ElementScalingFactor = ElementScalingFactor_;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
using TensorC = TensorC_;
|
||||
using TensorD = TensorD_;
|
||||
using TensorAux = TensorAux_;
|
||||
using VectorBias = VectorBias_;
|
||||
using VectorAlpha = VectorAlpha_;
|
||||
using VectorBeta = VectorBeta_;
|
||||
using ActivationFunctor = ActivationFunctor_;
|
||||
using BiasBinaryOp = BiasBinaryOp_;
|
||||
|
||||
using EngineC = typename TensorC::engine_type;
|
||||
using LayoutC = typename TensorC::layout_type;
|
||||
using EngineD = typename TensorD::engine_type;
|
||||
using LayoutD = typename TensorD::layout_type;
|
||||
static constexpr bool PerColumnBias = PerColumnBias_;
|
||||
ElementScalar alpha = ElementScalar(1);
|
||||
ElementScalar beta = ElementScalar(0);
|
||||
|
||||
TensorC C{};
|
||||
TensorD D{};
|
||||
VectorBias Bias{};
|
||||
TensorAux Aux{};
|
||||
VectorAlpha Valpha{};
|
||||
VectorBeta Vbeta{};
|
||||
ElementCompute st = ElementCompute(1);
|
||||
|
||||
ElementAccumulator* abs_max_D = nullptr;
|
||||
ElementAccumulator* abs_max_Aux = nullptr;
|
||||
|
||||
ElementScalingFactor scale_a = ElementScalingFactor(1);
|
||||
ElementScalingFactor scale_b = ElementScalingFactor(1);
|
||||
ElementScalingFactor scale_c = ElementScalingFactor(1);
|
||||
ElementScalingFactor scale_d = ElementScalingFactor(1);
|
||||
ElementScalingFactor scale_aux = ElementScalingFactor(1);
|
||||
|
||||
bool beta_per_channel_scaling = false;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GETT - General Tensor-Tensor contraction reference kernel with Blockwise scaling
|
||||
template <
|
||||
class MainloopParams,
|
||||
class EpilogueParams
|
||||
>
|
||||
void Gett(
|
||||
MainloopParams const& mainloop_params,
|
||||
EpilogueParams const& epilogue_params)
|
||||
{
|
||||
|
||||
static int constexpr kBlockM = cute::get<0>(typename MainloopParams::TileShape{});
|
||||
static int constexpr kBlockN = cute::get<1>(typename MainloopParams::TileShape{});
|
||||
// printf("mainloop_params.ScaleA.layout()"); cute::print(mainloop_params.ScaleA.layout()); printf("\n");
|
||||
// printf("mainloop_params.ScaleB.layout()"); cute::print(mainloop_params.ScaleB.layout()); printf("\n");
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp parallel for collapse(3)
|
||||
#endif
|
||||
for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) {
|
||||
for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) {
|
||||
for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) {
|
||||
typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN];
|
||||
gett_mainloop(mainloop_params, m, n, l, acc);
|
||||
gett_epilogue(epilogue_params, m, n, l, acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GETT - Mainloop
|
||||
template <class MainloopParams, class ElementAccumulator, int kBlockM, int kBlockN>
|
||||
void gett_mainloop(
|
||||
MainloopParams const& mainloop_params,
|
||||
int64_t m,
|
||||
int64_t n,
|
||||
int64_t l,
|
||||
ElementAccumulator (&acc)[kBlockM][kBlockN])
|
||||
{
|
||||
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B");
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B");
|
||||
|
||||
using cute::raw_pointer_cast;
|
||||
|
||||
using ElementA = typename ElementTraits<typename MainloopParams::EngineA::value_type>::type;
|
||||
using ElementB = typename ElementTraits<typename MainloopParams::EngineB::value_type>::type;
|
||||
using ElementBlockScaleA = typename ElementTraits<typename MainloopParams::EngineScaleA::value_type>::type;
|
||||
using ElementBlockScaleB = typename ElementTraits<typename MainloopParams::EngineScaleB::value_type>::type;
|
||||
|
||||
using RingOp = multiply_add<ElementAccumulator, ElementAccumulator, ElementAccumulator>;
|
||||
RingOp fma_op;
|
||||
|
||||
multiplies<ElementAccumulator> scale_op;
|
||||
|
||||
static int constexpr kBlockK = cute::get<2>(typename MainloopParams::TileShape{});;
|
||||
|
||||
// Tempo accumulators to seperate blockwise accumulation
|
||||
typename MainloopParams::ElementAccumulator acc_temp[kBlockM][kBlockN];
|
||||
|
||||
// Zero out accumulators
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
|
||||
acc_temp[m_b][n_b] = ElementAccumulator(0);
|
||||
}
|
||||
}
|
||||
|
||||
int64_t block_m = m / kBlockM;
|
||||
int64_t block_n = n / kBlockN;
|
||||
cute::Tensor blockscale_A = mainloop_params.ScaleA(block_m, _, l);
|
||||
cute::Tensor blockscale_B = mainloop_params.ScaleB(block_n, _, l);
|
||||
|
||||
// Compute on this k-block
|
||||
for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) {
|
||||
|
||||
// Load Blockwise scaling factor from blockscale Tensors for A and B
|
||||
int64_t block_k = k / kBlockK;
|
||||
ElementBlockScaleA scale_a = blockscale_A[block_k];
|
||||
ElementBlockScaleB scale_b = blockscale_B[block_k];
|
||||
|
||||
// Load A
|
||||
ElementAccumulator a_frag[kBlockM];
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
if (m + m_b < cute::size<0>(mainloop_params.A.layout())) {
|
||||
// Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
|
||||
a_frag[m_b] = static_cast<ElementAccumulator>(ElementA(mainloop_params.A(m + m_b, k, l)));
|
||||
} else {
|
||||
a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
|
||||
}
|
||||
}
|
||||
|
||||
// Load B
|
||||
ElementAccumulator b_frag[kBlockN];
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
if (n + n_b < cute::size<0>(mainloop_params.B.layout())) {
|
||||
// Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
|
||||
b_frag[n_b] = static_cast<ElementAccumulator>(ElementB(mainloop_params.B(n + n_b, k, l)));
|
||||
} else {
|
||||
b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
|
||||
}
|
||||
}
|
||||
|
||||
// do compute
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
acc_temp[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc_temp[m_b][n_b]);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply Blockwise-scaling at kBlockK boundary
|
||||
// (a) Apply block scaling factors on the partial accumulated results (acc_temp) at the kBlocK boundary
|
||||
// (b) Zero-out partial temporary (acc_temp),
|
||||
// (c) Update permanent (accu)
|
||||
if ((k+1) % kBlockK == 0) {
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a * scale_b;
|
||||
acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b];
|
||||
acc_temp[m_b][n_b] = ElementAccumulator(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GETT - Epilogue
|
||||
template <class EpilogueParams, class ElementAccumulator, int kBlockM, int kBlockN>
|
||||
void gett_epilogue(
|
||||
EpilogueParams const& epilogue_params,
|
||||
int64_t m,
|
||||
int64_t n,
|
||||
int64_t l,
|
||||
ElementAccumulator (&acc)[kBlockM][kBlockN])
|
||||
{
|
||||
static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B");
|
||||
static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B");
|
||||
|
||||
using cute::raw_pointer_cast;
|
||||
|
||||
using ElementCompute = typename EpilogueParams::ElementCompute;
|
||||
using ElementC = typename EpilogueParams::TensorC::value_type;
|
||||
using ElementD = typename EpilogueParams::TensorD::value_type;
|
||||
using ElementAux = typename EpilogueParams::TensorAux::value_type;
|
||||
using ElementBias = typename EpilogueParams::VectorBias::value_type;
|
||||
using ElementScalar = typename EpilogueParams::ElementScalar;
|
||||
using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor;
|
||||
using ActivationFunctor = typename EpilogueParams::ActivationFunctor;
|
||||
using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp;
|
||||
|
||||
constexpr bool PerColBias = EpilogueParams::PerColumnBias;
|
||||
constexpr bool IsScalingAndAmaxOutputNeeded =
|
||||
cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
|
||||
cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
|
||||
|
||||
constexpr bool IsScalingAndAmaxAuxOutputNeeded =
|
||||
cute::is_same_v<ElementAux, cutlass::float_e4m3_t> or
|
||||
cute::is_same_v<ElementAux, cutlass::float_e5m2_t>;
|
||||
|
||||
constexpr bool IsReLUAuxNeeded =
|
||||
(cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ReLu<ElementCompute>> or
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>) and
|
||||
cute::is_same_v<ElementAux, cutlass::uint1b_t>;
|
||||
constexpr bool IsClamp =
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>;
|
||||
|
||||
constexpr bool IsBackpropFusion =
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dGELU<ElementCompute>> or
|
||||
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dReLU<ElementCompute>>;
|
||||
|
||||
// Input related converter
|
||||
NumericConverter<ElementCompute, ElementAccumulator> accumulator_converter;
|
||||
NumericConverter<ElementCompute, ElementC> source_converter;
|
||||
NumericConverter<ElementCompute, ElementBias> bias_converter;
|
||||
[[maybe_unused]] NumericConverter<ElementCompute, ElementAux> aux_source_converter;
|
||||
|
||||
// Scale related converter
|
||||
NumericConverter<ElementCompute, ElementScalar> scale_converter;
|
||||
NumericConverter<ElementCompute, ElementScalingFactor> scaling_factor_converter;
|
||||
|
||||
// Abs max converter
|
||||
[[maybe_unused]] NumericConverter<ElementAccumulator, ElementCompute> abs_max_output_converter;
|
||||
|
||||
// Output related converter
|
||||
NumericConverter<ElementD, ElementCompute> destination_converter;
|
||||
[[maybe_unused]] NumericConverter<ElementAux, ElementCompute> aux_destination_converter;
|
||||
NumericConverter<ElementBias, ElementCompute> dBias_converter;
|
||||
|
||||
// Epilogue operations
|
||||
multiply_add<ElementCompute, ElementCompute, ElementCompute> epilogue_fma;
|
||||
multiplies<ElementCompute> mul;
|
||||
plus<ElementCompute> add;
|
||||
|
||||
// Activation operation
|
||||
ActivationFunctor activation;
|
||||
|
||||
// Bias binary operation
|
||||
BiasBinaryOp bias_op;
|
||||
|
||||
// Do conversion
|
||||
ElementCompute converted_alpha = scale_converter(epilogue_params.alpha);
|
||||
ElementCompute converted_beta = scale_converter(epilogue_params.beta);
|
||||
ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a);
|
||||
ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b);
|
||||
ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c);
|
||||
ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d);
|
||||
ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux);
|
||||
|
||||
// Init local var
|
||||
[[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0);
|
||||
[[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0);
|
||||
|
||||
converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b));
|
||||
converted_beta = mul(converted_beta, converted_scale_c);
|
||||
|
||||
ElementCompute inter_accum[kBlockM][kBlockN];
|
||||
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
ElementCompute local_dBias = ElementCompute(0);
|
||||
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
|
||||
// Convert every type to ElementCompute first, do compute, convert to output type, write it out
|
||||
ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]);
|
||||
// per-row alpha
|
||||
if (raw_pointer_cast(epilogue_params.Valpha.data())) {
|
||||
converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b));
|
||||
}
|
||||
ElementCompute output = mul(converted_alpha, converted_acc);
|
||||
|
||||
if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) {
|
||||
ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? n + n_b : m + m_b));
|
||||
output = bias_op(output, converted_bias);
|
||||
}
|
||||
|
||||
if (raw_pointer_cast(epilogue_params.C.data())) {
|
||||
ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l));
|
||||
// per-row beta
|
||||
if (epilogue_params.Vbeta.data()) {
|
||||
converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b));
|
||||
}
|
||||
output = epilogue_fma(converted_beta, converted_src, output);
|
||||
}
|
||||
|
||||
if constexpr (IsBackpropFusion) {
|
||||
ElementAux aux_input = ElementAux(0);
|
||||
if (raw_pointer_cast(epilogue_params.Aux.data())) {
|
||||
aux_input = epilogue_params.Aux(m + m_b, n + n_b, l);
|
||||
}
|
||||
|
||||
output = activation(output, aux_source_converter(aux_input));
|
||||
local_dBias = add(local_dBias, output);
|
||||
}
|
||||
else {
|
||||
if (raw_pointer_cast(epilogue_params.Aux.data())) {
|
||||
auto aux_output = output;
|
||||
if constexpr (IsScalingAndAmaxAuxOutputNeeded) {
|
||||
maximum_absolute_value_reduction<ElementCompute, true> amax_op;
|
||||
local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output);
|
||||
aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0));
|
||||
}
|
||||
|
||||
if constexpr (IsReLUAuxNeeded) {
|
||||
epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? uint1b_t(1) : uint1b_t(0);
|
||||
} else {
|
||||
epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (IsClamp) { // Treat Clamp as ReLU
|
||||
output = activation(output, {0, std::numeric_limits<ElementCompute>::max()});
|
||||
}
|
||||
else {
|
||||
output = activation(output);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (IsScalingAndAmaxOutputNeeded) {
|
||||
maximum_absolute_value_reduction<ElementCompute, true> amax_op;
|
||||
local_abs_max_output = amax_op(local_abs_max_output, output);
|
||||
output = epilogue_fma(converted_scale_d, output, ElementCompute(0));
|
||||
}
|
||||
|
||||
inter_accum[m_b][n_b] = ElementCompute(output);
|
||||
}
|
||||
} // n_b
|
||||
|
||||
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) {
|
||||
if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) {
|
||||
ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b));
|
||||
local_dBias = add(local_dBias, converted_dBias);
|
||||
epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias);
|
||||
}
|
||||
}
|
||||
} // m_b
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
|
||||
epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp critical(Abs_Max_Data_Update)
|
||||
#endif
|
||||
{
|
||||
if constexpr (IsScalingAndAmaxOutputNeeded) {
|
||||
if (epilogue_params.abs_max_D) {
|
||||
*epilogue_params.abs_max_D = maximum_with_nan_propogation<ElementAccumulator>{}(
|
||||
*epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (IsScalingAndAmaxAuxOutputNeeded) {
|
||||
if (epilogue_params.abs_max_Aux) {
|
||||
*epilogue_params.abs_max_Aux = maximum_with_nan_propogation<ElementAccumulator>{}(
|
||||
*epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GEMM - General Matrix-Matrix contraction without conjugation options
|
||||
template <
|
||||
class MainloopParams,
|
||||
class EpilogueParams
|
||||
>
|
||||
void Gemm3x(
|
||||
MainloopParams const& mainloop_params,
|
||||
EpilogueParams const& epilogue_params)
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{}));
|
||||
static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{}));
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{}));
|
||||
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "Only Rank3 Tensors (M, K, Batch_Count) "
|
||||
"with Batchmode are supported");
|
||||
// Lower the Matrix-Multiplication with Blockwise scaling (Gemm3x) to a Tensor Contraction (Gett).
|
||||
Gett(mainloop_params, epilogue_params);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // cutlass::reference::host
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,773 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Grouped scale Hopper FP8 Grouped GEMM example using CUTLASS 3.0 APIs for NVIDIA Hopper architecture
|
||||
This example demonstrates a grouped scaled FP8 Grouped GEMM using the new CUTLASS 3.0.
|
||||
APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows:
|
||||
1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA)
|
||||
which are more efficient than the Ampere tensor core instructions.
|
||||
2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large
|
||||
blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous
|
||||
copies between thread blocks in a cluster. This example also showcases on-the-fly modification of TMA
|
||||
descriptors to move between groups/problem_count (represented by groups).
|
||||
3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details).
|
||||
4. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the
|
||||
CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can
|
||||
improve performance.
|
||||
Examples:
|
||||
$ ./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling \
|
||||
--m=2816 --n=3072 --k=16384 --save_aux=false --save_amax=false \
|
||||
--raster=h --swizzle=2 --benchmark=./test_benchmark.txt
|
||||
|
||||
Where the test_benchmark.txt may look as such:
|
||||
0 256x512x128
|
||||
1 256x512x512
|
||||
2 512x256x128
|
||||
3 256x256x128
|
||||
4 256x512x1024
|
||||
5 1024x512x128 and so on
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <optional>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <cfloat>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
|
||||
// Includes from examples directory
|
||||
#include "helper.h"
|
||||
#include "hopper_fp8_commandline.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C matrix configuration
|
||||
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// D matrix configuration
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementBlockScale = float; // Element type for blockscaling during accumulation
|
||||
using ElementCompute = float; // Element type for epilogue computation
|
||||
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
constexpr int ScaleGranularityM = 1;
|
||||
constexpr int ScaleGranularityN = 128;
|
||||
constexpr int ScaleGranularityK = 128;
|
||||
|
||||
constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
|
||||
constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
|
||||
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
|
||||
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
|
||||
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
TileShape, ClusterShape,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
ElementD, LayoutD *, AlignmentD,
|
||||
EpilogueSchedule,
|
||||
FusionOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloopWithGroupWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, cute::tuple<LayoutA *, LayoutSFA *>, AlignmentA,
|
||||
ElementB, cute::tuple<LayoutB *, LayoutSFB *>, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
|
||||
>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloopWithGroupWiseScaling,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
|
||||
// Extract information from Gemm kernel.
|
||||
using EpilogueOutputOp = typename Gemm::EpilogueOutputOp;
|
||||
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||
|
||||
static_assert(cute::is_same_v<ElementAccumulator, ElementBlockScale>,
|
||||
"ElementAccumulator and ElementBlockScale should be same datatype");
|
||||
|
||||
/// Initialization
|
||||
|
||||
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
|
||||
|
||||
std::vector<int64_t> offset_A;
|
||||
std::vector<int64_t> offset_B;
|
||||
std::vector<int64_t> offset_C;
|
||||
std::vector<int64_t> offset_D;
|
||||
std::vector<int64_t> offset_blockscale_A;
|
||||
std::vector<int64_t> offset_blockscale_B;
|
||||
|
||||
std::vector<StrideA> stride_A_host;
|
||||
std::vector<StrideB> stride_B_host;
|
||||
std::vector<StrideC> stride_C_host;
|
||||
std::vector<StrideD> stride_D_host;
|
||||
std::vector<LayoutSFA> layout_SFA_host;
|
||||
std::vector<LayoutSFB> layout_SFB_host;
|
||||
|
||||
std::vector<ElementAccumulator> alpha_host;
|
||||
std::vector<ElementAccumulator> beta_host;
|
||||
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::DeviceAllocation<ElementA> block_A;
|
||||
cutlass::DeviceAllocation<ElementB> block_B;
|
||||
cutlass::DeviceAllocation<ElementC> block_C;
|
||||
cutlass::DeviceAllocation<ElementD> block_D;
|
||||
cutlass::DeviceAllocation<ElementBlockScale> blockscale_block_A;
|
||||
cutlass::DeviceAllocation<ElementBlockScale> blockscale_block_B;
|
||||
|
||||
cutlass::DeviceAllocation<const ElementA *> ptr_A;
|
||||
cutlass::DeviceAllocation<const ElementB *> ptr_B;
|
||||
cutlass::DeviceAllocation<const ElementC *> ptr_C;
|
||||
cutlass::DeviceAllocation<ElementD *> ptr_D;
|
||||
cutlass::DeviceAllocation<ElementD *> ptr_ref_D;
|
||||
cutlass::DeviceAllocation<const ElementBlockScale *> ptr_blockscale_A;
|
||||
cutlass::DeviceAllocation<const ElementBlockScale *> ptr_blockscale_B;
|
||||
|
||||
cutlass::DeviceAllocation<StrideA> stride_A;
|
||||
cutlass::DeviceAllocation<StrideB> stride_B;
|
||||
cutlass::DeviceAllocation<StrideC> stride_C;
|
||||
cutlass::DeviceAllocation<StrideD> stride_D;
|
||||
cutlass::DeviceAllocation<LayoutSFA> layout_SFA;
|
||||
cutlass::DeviceAllocation<LayoutSFB> layout_SFB;
|
||||
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams<Shape<int,int,int>>::RasterOrderOptions;
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element, class ScopeMin = std::nullopt_t, class ScopeMax = std::nullopt_t>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023,
|
||||
ScopeMin scope_min = std::nullopt, ScopeMax scope_max = std::nullopt) {
|
||||
|
||||
double _scope_max, _scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
if (bits_input == 1) {
|
||||
_scope_max = 2;
|
||||
_scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
_scope_max = 2;
|
||||
_scope_min = -2;
|
||||
} else if (bits_input == 16) {
|
||||
_scope_max = 5;
|
||||
_scope_min = -5;
|
||||
} else {
|
||||
_scope_max = 8;
|
||||
_scope_min = -8;
|
||||
}
|
||||
if constexpr (!std::is_same_v<ScopeMax, std::nullopt_t>) {
|
||||
_scope_max = scope_max;
|
||||
}
|
||||
if constexpr (!std::is_same_v<ScopeMin, std::nullopt_t>) {
|
||||
_scope_min = scope_min;
|
||||
}
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, (Element) _scope_max, (Element) _scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Allocates device-side data
|
||||
template <typename OptionType>
|
||||
void allocate(const OptionType &options) {
|
||||
|
||||
int64_t total_elements_A = 0;
|
||||
int64_t total_elements_B = 0;
|
||||
int64_t total_elements_C = 0;
|
||||
int64_t total_elements_D = 0;
|
||||
int64_t total_elements_blockscale_A = 0;
|
||||
int64_t total_elements_blockscale_B = 0;
|
||||
|
||||
offset_A.clear();
|
||||
offset_B.clear();
|
||||
offset_C.clear();
|
||||
offset_D.clear();
|
||||
offset_blockscale_A.clear();
|
||||
offset_blockscale_B.clear();
|
||||
stride_A_host.clear();
|
||||
stride_B_host.clear();
|
||||
stride_C_host.clear();
|
||||
stride_D_host.clear();
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
auto M = get<0>(problem);
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
|
||||
auto group_layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1));
|
||||
auto group_layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1));
|
||||
|
||||
offset_A.push_back(total_elements_A);
|
||||
offset_B.push_back(total_elements_B);
|
||||
offset_C.push_back(total_elements_C);
|
||||
offset_D.push_back(total_elements_D);
|
||||
offset_blockscale_A.push_back(total_elements_blockscale_A);
|
||||
offset_blockscale_B.push_back(total_elements_blockscale_B);
|
||||
|
||||
int64_t elements_A = M * K;
|
||||
int64_t elements_B = K * N;
|
||||
int64_t elements_C = M * N;
|
||||
int64_t elements_D = M * N;
|
||||
int64_t elements_blockscale_A = size(filter_zeros(group_layout_SFA));
|
||||
int64_t elements_blockscale_B = size(filter_zeros(group_layout_SFB));
|
||||
|
||||
total_elements_A += elements_A;
|
||||
total_elements_B += elements_B;
|
||||
total_elements_C += elements_C;
|
||||
total_elements_D += elements_D;
|
||||
total_elements_blockscale_A += elements_blockscale_A;
|
||||
total_elements_blockscale_B += elements_blockscale_B;
|
||||
|
||||
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
|
||||
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
|
||||
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}));
|
||||
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}));
|
||||
layout_SFA_host.push_back(group_layout_SFA);
|
||||
layout_SFB_host.push_back(group_layout_SFB);
|
||||
|
||||
}
|
||||
|
||||
block_A.reset(total_elements_A);
|
||||
block_B.reset(total_elements_B);
|
||||
block_C.reset(total_elements_C);
|
||||
block_D.reset(total_elements_D);
|
||||
block_alpha.reset(options.groups);
|
||||
block_beta.reset(options.groups);
|
||||
blockscale_block_A.reset(total_elements_blockscale_A);
|
||||
blockscale_block_B.reset(total_elements_blockscale_B);
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
template <typename OptionType>
|
||||
void initialize(const OptionType &options) {
|
||||
|
||||
problem_sizes.reset(options.groups);
|
||||
problem_sizes.copy_from_host(options.problem_sizes_host.data());
|
||||
|
||||
std::vector<ElementA *> ptr_A_host(options.groups);
|
||||
std::vector<ElementB *> ptr_B_host(options.groups);
|
||||
std::vector<ElementC *> ptr_C_host(options.groups);
|
||||
std::vector<ElementD *> ptr_D_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
|
||||
std::vector<ElementBlockScale *> ptr_blockscale_A_host(options.groups);
|
||||
std::vector<ElementBlockScale *> ptr_blockscale_B_host(options.groups);
|
||||
|
||||
alpha_host.clear();
|
||||
beta_host.clear();
|
||||
|
||||
for (int i = 0; i < options.groups; i++) {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
|
||||
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
|
||||
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
|
||||
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
|
||||
ptr_alpha_host.at(i) = block_alpha.get() + i;
|
||||
ptr_beta_host.at(i) = block_beta.get() + i;
|
||||
}
|
||||
|
||||
ptr_A.reset(options.groups);
|
||||
ptr_A.copy_from_host(ptr_A_host.data());
|
||||
|
||||
ptr_B.reset(options.groups);
|
||||
ptr_B.copy_from_host(ptr_B_host.data());
|
||||
|
||||
ptr_C.reset(options.groups);
|
||||
ptr_C.copy_from_host(ptr_C_host.data());
|
||||
|
||||
ptr_D.reset(options.groups);
|
||||
ptr_D.copy_from_host(ptr_D_host.data());
|
||||
|
||||
ptr_blockscale_A.reset(options.groups);
|
||||
ptr_blockscale_A.copy_from_host(ptr_blockscale_A_host.data());
|
||||
|
||||
ptr_blockscale_B.reset(options.groups);
|
||||
ptr_blockscale_B.copy_from_host(ptr_blockscale_B_host.data());
|
||||
|
||||
stride_A.reset(options.groups);
|
||||
stride_A.copy_from_host(stride_A_host.data());
|
||||
|
||||
stride_B.reset(options.groups);
|
||||
stride_B.copy_from_host(stride_B_host.data());
|
||||
|
||||
stride_C.reset(options.groups);
|
||||
stride_C.copy_from_host(stride_C_host.data());
|
||||
|
||||
stride_D.reset(options.groups);
|
||||
stride_D.copy_from_host(stride_D_host.data());
|
||||
|
||||
layout_SFA.reset(options.groups);
|
||||
layout_SFA.copy_from_host(layout_SFA_host.data());
|
||||
|
||||
layout_SFB.reset(options.groups);
|
||||
layout_SFB.copy_from_host(layout_SFB_host.data());
|
||||
|
||||
alpha_device.reset(options.groups);
|
||||
alpha_device.copy_from_host(ptr_alpha_host.data());
|
||||
beta_device.reset(options.groups);
|
||||
beta_device.copy_from_host(ptr_beta_host.data());
|
||||
|
||||
initialize_block(block_A, seed + 2022);
|
||||
initialize_block(block_B, seed + 2023);
|
||||
initialize_block(block_C, seed + 2024);
|
||||
initialize_block(blockscale_block_A, seed + 2025, -1, 1);
|
||||
initialize_block(blockscale_block_B, seed + 2026, -1, 1);
|
||||
|
||||
block_alpha.copy_from_host(alpha_host.data());
|
||||
block_beta.copy_from_host(beta_host.data());
|
||||
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
template<typename GemmArguments, typename OptionType>
|
||||
GemmArguments args_from_options(const OptionType &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
int device_id = 0;
|
||||
cutlass::KernelHardwareInfo kernel_hw_info = cutlass::KernelHardwareInfo::make_kernel_hardware_info<typename Gemm::GemmKernel>(device_id);
|
||||
|
||||
GemmArguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), host_problem_shapes_available ? options.problem_sizes_host.data() : (decltype(options.problem_sizes_host.data())) nullptr},
|
||||
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(),
|
||||
ptr_blockscale_A.get(), layout_SFA.get(),
|
||||
ptr_blockscale_B.get(), layout_SFB.get()
|
||||
},
|
||||
{
|
||||
{}, // epilogue.thread
|
||||
ptr_C.get(), stride_C.get(),
|
||||
ptr_D.get(), stride_D.get()
|
||||
},
|
||||
kernel_hw_info
|
||||
};
|
||||
|
||||
auto &fusion_args = arguments.epilogue.thread;
|
||||
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
|
||||
// If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.beta = options.beta;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = nullptr;
|
||||
fusion_args.beta_ptr_array = nullptr;
|
||||
// Single alpha and beta for all groups
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
|
||||
}
|
||||
else {
|
||||
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
|
||||
fusion_args.alpha = 0;
|
||||
fusion_args.beta = 0;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = alpha_device.get();
|
||||
fusion_args.beta_ptr_array = beta_device.get();
|
||||
// One alpha and beta per each group
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
|
||||
}
|
||||
|
||||
arguments.scheduler.raster_order = options.raster;
|
||||
// The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8)
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
template <typename OptionType>
|
||||
bool verify(const OptionType &options) {
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
std::vector<ElementA> block_A_host(block_A.size());
|
||||
std::vector<ElementB> block_B_host(block_B.size());
|
||||
std::vector<ElementC> block_C_host(block_C.size());
|
||||
std::vector<ElementD> block_D_host_kernel(block_D.size());
|
||||
std::vector<ElementD> block_D_host_ref(block_D.size());
|
||||
std::vector<ElementBlockScale> blockscale_block_A_host(blockscale_block_A.size());
|
||||
std::vector<ElementBlockScale> blockscale_block_B_host(blockscale_block_B.size());
|
||||
|
||||
block_A.copy_to_host(block_A_host.data());
|
||||
block_B.copy_to_host(block_B_host.data());
|
||||
block_C.copy_to_host(block_C_host.data());
|
||||
block_D.copy_to_host(block_D_host_kernel.data());
|
||||
blockscale_block_A.copy_to_host(blockscale_block_A_host.data());
|
||||
blockscale_block_B.copy_to_host(blockscale_block_B_host.data());
|
||||
|
||||
bool passed = true;
|
||||
for (int group_idx = 0; group_idx < options.groups; group_idx++) {
|
||||
// Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape
|
||||
auto [m, n, k] = options.problem_sizes_host.at(group_idx);
|
||||
auto gemm_problem_shape = cute::make_shape(m, n, k);
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
auto A = cute::make_tensor(block_A_host.data() + offset_A.at(group_idx),
|
||||
cute::make_layout(
|
||||
cute::make_shape(m, k, 1),
|
||||
stride_A_host.at(group_idx)
|
||||
)
|
||||
);
|
||||
auto B = cute::make_tensor(block_B_host.data() + offset_B.at(group_idx),
|
||||
cute::make_layout(
|
||||
cute::make_shape(n, k, 1),
|
||||
stride_B_host.at(group_idx)
|
||||
)
|
||||
);
|
||||
auto C = cute::make_tensor(block_C_host.data() + offset_C.at(group_idx),
|
||||
cute::make_layout(
|
||||
cute::make_shape(m, n, 1),
|
||||
stride_C_host.at(group_idx)
|
||||
)
|
||||
);
|
||||
auto D = cute::make_tensor(block_D_host_ref.data() + offset_D.at(group_idx),
|
||||
cute::make_layout(
|
||||
cute::make_shape(m, n, 1),
|
||||
stride_D_host.at(group_idx)
|
||||
)
|
||||
);
|
||||
|
||||
auto SFA = cute::make_tensor(blockscale_block_A_host.data() + offset_blockscale_A.at(group_idx),
|
||||
layout_SFA_host.at(group_idx));
|
||||
auto SFB = cute::make_tensor(blockscale_block_B_host.data() + offset_blockscale_B.at(group_idx),
|
||||
layout_SFB_host.at(group_idx));
|
||||
|
||||
using unused_t = decltype(D);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator,
|
||||
decltype(A),
|
||||
decltype(SFA),
|
||||
decltype(B),
|
||||
decltype(SFB)
|
||||
> mainloop_params{A, SFA, B, SFB};
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
ElementScalar,
|
||||
ElementScalar,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
decltype(C),
|
||||
decltype(D),
|
||||
unused_t, // bias
|
||||
unused_t, // Aux
|
||||
unused_t, // valpha
|
||||
unused_t // vbeta
|
||||
> epilogue_params;
|
||||
|
||||
epilogue_params.C = C;
|
||||
epilogue_params.D = D;
|
||||
epilogue_params.alpha = alpha_host.at(group_idx);
|
||||
epilogue_params.beta = beta_host.at(group_idx);
|
||||
|
||||
// get reference result
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
auto this_group_passed = std::equal(
|
||||
// std::execution::par_unseq,
|
||||
block_D_host_ref.data() + offset_D.at(group_idx),
|
||||
block_D_host_ref.data() + offset_D.at(group_idx) + m * n,
|
||||
block_D_host_kernel.data() + offset_D.at(group_idx)
|
||||
);
|
||||
|
||||
passed &= this_group_passed;
|
||||
|
||||
#if 0
|
||||
std::cout << "Group: " << group_idx << " M: " << m << " N: " << n << " K: " << k << " Status: " << this_group_passed << std::endl;
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename OptionType>
|
||||
int run(OptionType &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options<typename Gemm::Arguments>(options, host_problem_shapes_available);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::string raster = "Heuristic";
|
||||
|
||||
if (options.raster == RasterOrderOptions::AlongN) {
|
||||
raster = "Along N";
|
||||
}
|
||||
else if (options.raster == RasterOrderOptions::AlongM) {
|
||||
raster = "Along M";
|
||||
}
|
||||
|
||||
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
std::cout << " " << options.problem_sizes_host.at(i);
|
||||
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
|
||||
}
|
||||
std::cout << " Groups : " << options.groups << std::endl;
|
||||
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
|
||||
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
|
||||
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
|
||||
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) {
|
||||
std::cerr << "This example requires CUDA 12.3 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major != 9) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options<RasterOrderOptions, ProblemShape> options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
std::cout << "Running tests with host problem shapes:" << std::endl;
|
||||
run(options, true);
|
||||
std::cout << "Running tests without host problem shapes:" << std::endl;
|
||||
run(options, false);
|
||||
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,781 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Grouped scale Hopper FP8 Grouped GEMM example using CUTLASS 3.0 APIs for NVIDIA Hopper architecture
|
||||
This example demonstrates a grouped scaled FP8 Grouped GEMM using the new CUTLASS 3.0.
|
||||
APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows:
|
||||
1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA)
|
||||
which are more efficient than the Ampere tensor core instructions.
|
||||
2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large
|
||||
blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous
|
||||
copies between thread blocks in a cluster. This example also showcases on-the-fly modification of TMA
|
||||
descriptors to move between groups/problem_count (represented by groups).
|
||||
3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details).
|
||||
4. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the
|
||||
CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can
|
||||
improve performance.
|
||||
5. This example is tuned specifically for the sparse groups case, where the number of active groups (groups
|
||||
with non-zero problem count) is much smaller than the total number of groups.
|
||||
Examples:
|
||||
$ ./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups \
|
||||
--m=2816 --n=3072 --k=16384 --save_aux=false --save_amax=false \
|
||||
--raster=h --swizzle=2 --benchmark=./test_benchmark.txt
|
||||
|
||||
Where the test_benchmark.txt may look as such:
|
||||
0 256x512x128
|
||||
1 256x512x512
|
||||
2 512x256x128
|
||||
3 256x256x128
|
||||
4 256x512x1024
|
||||
5 1024x512x128 and so on
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <optional>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <cfloat>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
|
||||
// Includes from examples directory
|
||||
#include "helper.h"
|
||||
#include "hopper_fp8_commandline.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C matrix configuration
|
||||
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// D matrix configuration
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementBlockScale = float; // Element type for blockscaling during accumulation
|
||||
using ElementCompute = float; // Element type for epilogue computation
|
||||
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
|
||||
using TileShape = Shape<_128,_128,_128>; // This one is just to make the compiler happy with verify()...
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
static constexpr int ScaleGranularityM = 1;
|
||||
static constexpr int ScaleGranularityN = 128;
|
||||
static constexpr int ScaleGranularityK = 128;
|
||||
static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
|
||||
static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
|
||||
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
|
||||
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
|
||||
|
||||
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
TileShape, ClusterShape,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
ElementD, LayoutD *, AlignmentD,
|
||||
EpilogueSchedule,
|
||||
FusionOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloopWithGroupWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, cute::tuple<LayoutA *, LayoutSFA *>, AlignmentA,
|
||||
ElementB, cute::tuple<LayoutB *, LayoutSFB *>, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
|
||||
>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloopWithGroupWiseScaling,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Extract information from Gemm kernel.
|
||||
using EpilogueOutputOp = typename Gemm::EpilogueOutputOp;
|
||||
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
|
||||
using ActivationFunctor = typename EpilogueOutputOp::ActivationFn;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||
|
||||
static_assert(cute::is_same_v<ElementAccumulator, ElementBlockScale>,
|
||||
"ElementAccumulator and ElementBlockScale should be same datatype");
|
||||
|
||||
/// Initialization
|
||||
|
||||
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
|
||||
|
||||
std::vector<int64_t> offset_A;
|
||||
std::vector<int64_t> offset_B;
|
||||
std::vector<int64_t> offset_C;
|
||||
std::vector<int64_t> offset_D;
|
||||
std::vector<int64_t> offset_blockscale_A;
|
||||
std::vector<int64_t> offset_blockscale_B;
|
||||
|
||||
std::vector<StrideA> stride_A_host;
|
||||
std::vector<StrideB> stride_B_host;
|
||||
std::vector<StrideC> stride_C_host;
|
||||
std::vector<StrideD> stride_D_host;
|
||||
std::vector<LayoutSFA> layout_SFA_host;
|
||||
std::vector<LayoutSFB> layout_SFB_host;
|
||||
|
||||
std::vector<ElementAccumulator> alpha_host;
|
||||
std::vector<ElementAccumulator> beta_host;
|
||||
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::DeviceAllocation<ElementA> block_A;
|
||||
cutlass::DeviceAllocation<ElementB> block_B;
|
||||
cutlass::DeviceAllocation<ElementC> block_C;
|
||||
cutlass::DeviceAllocation<ElementD> block_D;
|
||||
cutlass::DeviceAllocation<ElementBlockScale> blockscale_block_A;
|
||||
cutlass::DeviceAllocation<ElementBlockScale> blockscale_block_B;
|
||||
|
||||
cutlass::DeviceAllocation<const ElementA *> ptr_A;
|
||||
cutlass::DeviceAllocation<const ElementB *> ptr_B;
|
||||
cutlass::DeviceAllocation<const ElementC *> ptr_C;
|
||||
cutlass::DeviceAllocation<ElementD *> ptr_D;
|
||||
cutlass::DeviceAllocation<ElementD *> ptr_ref_D;
|
||||
cutlass::DeviceAllocation<const ElementBlockScale *> ptr_blockscale_A;
|
||||
cutlass::DeviceAllocation<const ElementBlockScale *> ptr_blockscale_B;
|
||||
|
||||
cutlass::DeviceAllocation<StrideA> stride_A;
|
||||
cutlass::DeviceAllocation<StrideB> stride_B;
|
||||
cutlass::DeviceAllocation<StrideC> stride_C;
|
||||
cutlass::DeviceAllocation<StrideD> stride_D;
|
||||
cutlass::DeviceAllocation<LayoutSFA> layout_SFA;
|
||||
cutlass::DeviceAllocation<LayoutSFB> layout_SFB;
|
||||
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams<Shape<int,int,int>>::RasterOrderOptions;
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
double gbps;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
double gbps = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), gbps(gbps), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element, class ScopeMin = std::nullopt_t, class ScopeMax = std::nullopt_t>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023,
|
||||
ScopeMin scope_min = std::nullopt, ScopeMax scope_max = std::nullopt) {
|
||||
|
||||
double _scope_max, _scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
if (bits_input == 1) {
|
||||
_scope_max = 2;
|
||||
_scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
_scope_max = 2;
|
||||
_scope_min = -2;
|
||||
} else if (bits_input == 16) {
|
||||
_scope_max = 5;
|
||||
_scope_min = -5;
|
||||
} else {
|
||||
_scope_max = 8;
|
||||
_scope_min = -8;
|
||||
}
|
||||
if constexpr (!std::is_same_v<ScopeMax, std::nullopt_t>) {
|
||||
_scope_max = scope_max;
|
||||
}
|
||||
if constexpr (!std::is_same_v<ScopeMin, std::nullopt_t>) {
|
||||
_scope_min = scope_min;
|
||||
}
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, (Element) _scope_max, (Element) _scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Allocates device-side data
|
||||
template <typename OptionType>
|
||||
void allocate(const OptionType &options) {
|
||||
|
||||
int64_t total_elements_A = 0;
|
||||
int64_t total_elements_B = 0;
|
||||
int64_t total_elements_C = 0;
|
||||
int64_t total_elements_D = 0;
|
||||
int64_t total_elements_blockscale_A = 0;
|
||||
int64_t total_elements_blockscale_B = 0;
|
||||
|
||||
offset_A.clear();
|
||||
offset_B.clear();
|
||||
offset_C.clear();
|
||||
offset_D.clear();
|
||||
offset_blockscale_A.clear();
|
||||
offset_blockscale_B.clear();
|
||||
stride_A_host.clear();
|
||||
stride_B_host.clear();
|
||||
stride_C_host.clear();
|
||||
stride_D_host.clear();
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
|
||||
auto problem = options.problem_sizes_after_alignment_host.at(i);
|
||||
auto M = get<0>(problem);
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
|
||||
auto group_layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1));
|
||||
auto group_layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1));
|
||||
|
||||
offset_A.push_back(total_elements_A);
|
||||
offset_B.push_back(total_elements_B);
|
||||
offset_C.push_back(total_elements_C);
|
||||
offset_D.push_back(total_elements_D);
|
||||
offset_blockscale_A.push_back(total_elements_blockscale_A);
|
||||
offset_blockscale_B.push_back(total_elements_blockscale_B);
|
||||
|
||||
int64_t elements_A = M * K;
|
||||
int64_t elements_B = K * N;
|
||||
int64_t elements_C = M * N;
|
||||
int64_t elements_D = M * N;
|
||||
int64_t elements_blockscale_A = size(filter_zeros(group_layout_SFA));
|
||||
int64_t elements_blockscale_B = size(filter_zeros(group_layout_SFB));
|
||||
|
||||
total_elements_A += elements_A;
|
||||
total_elements_B += elements_B;
|
||||
total_elements_C += elements_C;
|
||||
total_elements_D += elements_D;
|
||||
total_elements_blockscale_A += elements_blockscale_A;
|
||||
total_elements_blockscale_B += elements_blockscale_B;
|
||||
|
||||
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
|
||||
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
|
||||
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}));
|
||||
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}));
|
||||
layout_SFA_host.push_back(group_layout_SFA);
|
||||
layout_SFB_host.push_back(group_layout_SFB);
|
||||
|
||||
}
|
||||
|
||||
block_A.reset(total_elements_A);
|
||||
block_B.reset(total_elements_B);
|
||||
block_C.reset(total_elements_C);
|
||||
block_D.reset(total_elements_D);
|
||||
block_alpha.reset(options.groups);
|
||||
block_beta.reset(options.groups);
|
||||
blockscale_block_A.reset(total_elements_blockscale_A);
|
||||
blockscale_block_B.reset(total_elements_blockscale_B);
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
template <typename OptionType>
|
||||
void initialize(const OptionType &options) {
|
||||
|
||||
problem_sizes.reset(options.groups);
|
||||
problem_sizes.copy_from_host(options.problem_sizes_after_alignment_host.data());
|
||||
|
||||
std::vector<ElementA *> ptr_A_host(options.groups);
|
||||
std::vector<ElementB *> ptr_B_host(options.groups);
|
||||
std::vector<ElementC *> ptr_C_host(options.groups);
|
||||
std::vector<ElementD *> ptr_D_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
|
||||
std::vector<ElementBlockScale *> ptr_blockscale_A_host(options.groups);
|
||||
std::vector<ElementBlockScale *> ptr_blockscale_B_host(options.groups);
|
||||
|
||||
alpha_host.clear();
|
||||
beta_host.clear();
|
||||
|
||||
for (int i = 0; i < options.groups; i++) {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i);
|
||||
ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i);
|
||||
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
|
||||
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
|
||||
ptr_alpha_host.at(i) = block_alpha.get() + i;
|
||||
ptr_beta_host.at(i) = block_beta.get() + i;
|
||||
}
|
||||
|
||||
ptr_A.reset(options.groups);
|
||||
ptr_A.copy_from_host(ptr_A_host.data());
|
||||
|
||||
ptr_B.reset(options.groups);
|
||||
ptr_B.copy_from_host(ptr_B_host.data());
|
||||
|
||||
ptr_C.reset(options.groups);
|
||||
ptr_C.copy_from_host(ptr_C_host.data());
|
||||
|
||||
ptr_D.reset(options.groups);
|
||||
ptr_D.copy_from_host(ptr_D_host.data());
|
||||
|
||||
ptr_blockscale_A.reset(options.groups);
|
||||
ptr_blockscale_A.copy_from_host(ptr_blockscale_A_host.data());
|
||||
|
||||
ptr_blockscale_B.reset(options.groups);
|
||||
ptr_blockscale_B.copy_from_host(ptr_blockscale_B_host.data());
|
||||
|
||||
stride_A.reset(options.groups);
|
||||
stride_A.copy_from_host(stride_A_host.data());
|
||||
|
||||
stride_B.reset(options.groups);
|
||||
stride_B.copy_from_host(stride_B_host.data());
|
||||
|
||||
stride_C.reset(options.groups);
|
||||
stride_C.copy_from_host(stride_C_host.data());
|
||||
|
||||
stride_D.reset(options.groups);
|
||||
stride_D.copy_from_host(stride_D_host.data());
|
||||
|
||||
layout_SFA.reset(options.groups);
|
||||
layout_SFA.copy_from_host(layout_SFA_host.data());
|
||||
|
||||
layout_SFB.reset(options.groups);
|
||||
layout_SFB.copy_from_host(layout_SFB_host.data());
|
||||
|
||||
alpha_device.reset(options.groups);
|
||||
alpha_device.copy_from_host(ptr_alpha_host.data());
|
||||
beta_device.reset(options.groups);
|
||||
beta_device.copy_from_host(ptr_beta_host.data());
|
||||
|
||||
initialize_block(block_A, seed + 2022);
|
||||
initialize_block(block_B, seed + 2023);
|
||||
initialize_block(block_C, seed + 2024);
|
||||
initialize_block(blockscale_block_A, seed + 2025, -1, 1);
|
||||
initialize_block(blockscale_block_B, seed + 2026, -1, 1);
|
||||
|
||||
block_alpha.copy_from_host(alpha_host.data());
|
||||
block_beta.copy_from_host(beta_host.data());
|
||||
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
template<typename GemmArguments, typename OptionType>
|
||||
GemmArguments args_from_options(const OptionType &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
int device_id = 0;
|
||||
cutlass::KernelHardwareInfo kernel_hw_info = cutlass::KernelHardwareInfo::make_kernel_hardware_info<typename Gemm::GemmKernel>(device_id);
|
||||
|
||||
GemmArguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), host_problem_shapes_available ? options.problem_sizes_after_alignment_host.data() : (decltype(options.problem_sizes_after_alignment_host.data())) nullptr},
|
||||
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(),
|
||||
ptr_blockscale_A.get(), layout_SFA.get(),
|
||||
ptr_blockscale_B.get(), layout_SFB.get()
|
||||
},
|
||||
{
|
||||
{}, // epilogue.thread
|
||||
ptr_C.get(), stride_C.get(),
|
||||
ptr_D.get(), stride_D.get()
|
||||
},
|
||||
kernel_hw_info
|
||||
};
|
||||
|
||||
auto &fusion_args = arguments.epilogue.thread;
|
||||
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
|
||||
// If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.beta = options.beta;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = nullptr;
|
||||
fusion_args.beta_ptr_array = nullptr;
|
||||
// Single alpha and beta for all groups
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
|
||||
}
|
||||
else {
|
||||
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
|
||||
fusion_args.alpha = 0;
|
||||
fusion_args.beta = 0;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = alpha_device.get();
|
||||
fusion_args.beta_ptr_array = beta_device.get();
|
||||
// One alpha and beta per each group
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
|
||||
}
|
||||
|
||||
arguments.scheduler.raster_order = options.raster;
|
||||
// The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8)
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
template <typename OptionType>
|
||||
bool verify(const OptionType &options) {
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
std::vector<ElementA> block_A_host(block_A.size());
|
||||
std::vector<ElementB> block_B_host(block_B.size());
|
||||
std::vector<ElementC> block_C_host(block_C.size());
|
||||
std::vector<ElementD> block_D_host_kernel(block_D.size());
|
||||
std::vector<ElementD> block_D_host_ref(block_D.size());
|
||||
std::vector<ElementBlockScale> blockscale_block_A_host(blockscale_block_A.size());
|
||||
std::vector<ElementBlockScale> blockscale_block_B_host(blockscale_block_B.size());
|
||||
|
||||
block_A.copy_to_host(block_A_host.data());
|
||||
block_B.copy_to_host(block_B_host.data());
|
||||
block_C.copy_to_host(block_C_host.data());
|
||||
block_D.copy_to_host(block_D_host_kernel.data());
|
||||
blockscale_block_A.copy_to_host(blockscale_block_A_host.data());
|
||||
blockscale_block_B.copy_to_host(blockscale_block_B_host.data());
|
||||
|
||||
bool passed = true;
|
||||
for (int group_idx = 0; group_idx < options.groups; group_idx++) {
|
||||
// Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape
|
||||
auto [m, n, k] = options.problem_sizes_after_alignment_host.at(group_idx);
|
||||
auto gemm_problem_shape = cute::make_shape(m, n, k);
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
auto A = cute::make_tensor(block_A_host.data() + offset_A.at(group_idx),
|
||||
cute::make_layout(
|
||||
cute::make_shape(m, k, 1),
|
||||
stride_A_host.at(group_idx)
|
||||
)
|
||||
);
|
||||
auto B = cute::make_tensor(block_B_host.data() + offset_B.at(group_idx),
|
||||
cute::make_layout(
|
||||
cute::make_shape(n, k, 1),
|
||||
stride_B_host.at(group_idx)
|
||||
)
|
||||
);
|
||||
auto C = cute::make_tensor(block_C_host.data() + offset_C.at(group_idx),
|
||||
cute::make_layout(
|
||||
cute::make_shape(m, n, 1),
|
||||
stride_C_host.at(group_idx)
|
||||
)
|
||||
);
|
||||
auto D = cute::make_tensor(block_D_host_ref.data() + offset_D.at(group_idx),
|
||||
cute::make_layout(
|
||||
cute::make_shape(m, n, 1),
|
||||
stride_D_host.at(group_idx)
|
||||
)
|
||||
);
|
||||
|
||||
auto SFA = cute::make_tensor(blockscale_block_A_host.data() + offset_blockscale_A.at(group_idx),
|
||||
layout_SFA_host.at(group_idx));
|
||||
auto SFB = cute::make_tensor(blockscale_block_B_host.data() + offset_blockscale_B.at(group_idx),
|
||||
layout_SFB_host.at(group_idx));
|
||||
|
||||
using unused_t = decltype(D);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator,
|
||||
decltype(A),
|
||||
decltype(SFA),
|
||||
decltype(B),
|
||||
decltype(SFB)
|
||||
> mainloop_params{A, SFA, B, SFB};
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
ElementScalar,
|
||||
ElementScalar,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
decltype(C),
|
||||
decltype(D)
|
||||
> epilogue_params;
|
||||
|
||||
epilogue_params.C = C;
|
||||
epilogue_params.D = D;
|
||||
epilogue_params.alpha = alpha_host.at(group_idx);
|
||||
epilogue_params.beta = beta_host.at(group_idx);
|
||||
|
||||
// get reference result
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
auto this_group_passed = std::equal(
|
||||
// std::execution::par_unseq,
|
||||
block_D_host_ref.data() + offset_D.at(group_idx),
|
||||
block_D_host_ref.data() + offset_D.at(group_idx) + m * n,
|
||||
block_D_host_kernel.data() + offset_D.at(group_idx)
|
||||
);
|
||||
|
||||
passed &= this_group_passed;
|
||||
|
||||
#if 0
|
||||
std::cout << "Group: " << group_idx << " M: " << m << " N: " << n << " K: " << k << " Status: " << this_group_passed << std::endl;
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename OptionType>
|
||||
int run(OptionType &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options<typename Gemm::Arguments>(options, host_problem_shapes_available);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0) {
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
result.gbps = options.template gbps<ElementA,
|
||||
ElementB,
|
||||
ElementC,
|
||||
ElementD,
|
||||
ElementBlockScale,
|
||||
TileShape,
|
||||
ScaleMsPerTile,
|
||||
ScaleNsPerTile>(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::string raster = "Heuristic";
|
||||
|
||||
if (options.raster == RasterOrderOptions::AlongN) {
|
||||
raster = "Along N";
|
||||
}
|
||||
else if (options.raster == RasterOrderOptions::AlongM) {
|
||||
raster = "Along M";
|
||||
}
|
||||
|
||||
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
std::cout << " " << options.problem_sizes_host.at(i);
|
||||
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
|
||||
}
|
||||
std::cout << " Groups : " << options.groups << std::endl;
|
||||
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
|
||||
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
|
||||
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
|
||||
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) {
|
||||
std::cerr << "This example requires CUDA 12.3 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major != 9) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture or "
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options<RasterOrderOptions, ProblemShape> options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
run(options, true);
|
||||
|
||||
std::cout << "Running tests without host problem shapes:" << std::endl;
|
||||
run(options, false);
|
||||
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,84 @@
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# Note that we set --iterations=0 for all tests below to disable the performance benchmarking.
|
||||
# Only the correctness check will be run by these commands.
|
||||
|
||||
set(TEST_RANDOM --iterations=0) # Random problem sizes
|
||||
set(TEST_RANDOM_LARGE_GROUP --groups=500 --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes
|
||||
set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=500 --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_EPILOGUE_OP --beta=0.5 --iterations=0) # Random problem sizes
|
||||
set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_FIXED --m=2048 --n=5120 --k=512 --groups=50 --iterations=0) # Fixed problem sizes
|
||||
set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=512 --iterations=0) # Fixed problem sizes
|
||||
|
||||
set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes
|
||||
set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=500 --iterations=0) # Small problem sizes
|
||||
|
||||
cutlass_example_add_executable(
|
||||
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling
|
||||
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_RANDOM
|
||||
TEST_RANDOM_LARGE_GROUP
|
||||
TEST_EPILOGUE
|
||||
TEST_EPILOGUE_LARGE_GROUP
|
||||
TEST_EPILOGUE_OP
|
||||
TEST_EPILOGUE_OP_LARGE_GROUP
|
||||
TEST_FIXED
|
||||
TEST_FIXED_LARGE_GROUP
|
||||
TEST_SMALL
|
||||
TEST_SMALL_LARGE_GROUP
|
||||
)
|
||||
|
||||
# MSVC will fail to compile this example with the following error:
|
||||
# fatal error C1083: Cannot open source file: <Some Mojibake>: No such file or directory [...\examples\68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling\68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.vcxproj]
|
||||
# This is a known issue and we are working on a fix.
|
||||
if (NOT MSVC)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups
|
||||
68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_RANDOM
|
||||
TEST_RANDOM_LARGE_GROUP
|
||||
TEST_EPILOGUE
|
||||
TEST_EPILOGUE_LARGE_GROUP
|
||||
TEST_EPILOGUE_OP
|
||||
TEST_EPILOGUE_OP_LARGE_GROUP
|
||||
TEST_FIXED
|
||||
TEST_FIXED_LARGE_GROUP
|
||||
TEST_SMALL
|
||||
TEST_SMALL_LARGE_GROUP
|
||||
)
|
||||
|
||||
endif()
|
||||
@ -0,0 +1,271 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
// Command line options parsing
|
||||
template<typename _RasterOrderOptions, typename _ProblemShape>
|
||||
struct Options {
|
||||
|
||||
using RasterOrderOptions = _RasterOrderOptions;
|
||||
using ProblemShape = _ProblemShape;
|
||||
|
||||
bool help = false;
|
||||
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
int iterations = 1000;
|
||||
int m = 1024, n = 512, k = 1024, groups = 10;
|
||||
std::string benchmark_path;
|
||||
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_after_alignment_host;
|
||||
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host;
|
||||
int const tma_alignment_bits = 128;
|
||||
int const alignment = tma_alignment_bits / cutlass::sizeof_bits<cutlass::float_e4m3_t>::value;
|
||||
int const k_alignment = 128;
|
||||
int const m_alignment = 128;
|
||||
int const n_alignment = 128;
|
||||
|
||||
RasterOrderOptions raster;
|
||||
int swizzle;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("groups", groups);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
|
||||
char raster_char;
|
||||
cmd.get_cmd_line_argument("raster", raster_char);
|
||||
|
||||
if (raster_char == 'N' || raster_char == 'n') {
|
||||
raster = RasterOrderOptions::AlongN;
|
||||
}
|
||||
else if (raster_char == 'M' || raster_char == 'm') {
|
||||
raster = RasterOrderOptions::AlongM;
|
||||
}
|
||||
else if (raster_char == 'H' || raster_char == 'h') {
|
||||
raster = RasterOrderOptions::Heuristic;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle, 1);
|
||||
cmd.get_cmd_line_argument("benchmark", benchmark_path);
|
||||
|
||||
// Decide how to initialize the problems
|
||||
if (!benchmark_path.empty()) {
|
||||
if (!benchmark_problems()) {
|
||||
problem_sizes_after_alignment_host.clear();
|
||||
problem_sizes_host.clear();
|
||||
return;
|
||||
}
|
||||
}
|
||||
else {
|
||||
randomize_problems(cmd);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void randomize_problems(cutlass::CommandLine &cmd) {
|
||||
int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1;
|
||||
cmd.get_cmd_line_argument("m", cmd_line_m);
|
||||
cmd.get_cmd_line_argument("n", cmd_line_n);
|
||||
cmd.get_cmd_line_argument("k", cmd_line_k);
|
||||
|
||||
problem_sizes_after_alignment_host.reserve(groups);
|
||||
problem_sizes_host.reserve(groups);
|
||||
for (int i = groups; i > 0; i--) {
|
||||
int m = cmd_line_m;
|
||||
int n = cmd_line_n;
|
||||
int k = cmd_line_k;
|
||||
if (m < 1) {
|
||||
m = m_alignment * ((rand() % (64 * alignment / m_alignment)) + 1);
|
||||
}
|
||||
if (n < 1) {
|
||||
n = n_alignment * ((rand() % (64 * alignment / n_alignment)) + 1);
|
||||
}
|
||||
if (k < 1) {
|
||||
k = k_alignment * ((rand() % (32 * alignment / k_alignment)) + 1);
|
||||
}
|
||||
problem_sizes_after_alignment_host.push_back({m, n, k});
|
||||
problem_sizes_host.push_back({m, n, k});
|
||||
}
|
||||
}
|
||||
|
||||
/// Load a benchmark
|
||||
bool benchmark_problems() {
|
||||
std::ifstream file(benchmark_path);
|
||||
if (!file.good()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
while (file.good()) {
|
||||
|
||||
int idx = -1;
|
||||
std::string extent_str;
|
||||
|
||||
file >> idx >> extent_str;
|
||||
|
||||
if (idx < 0 || extent_str.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
cutlass::gemm::GemmCoord extent_after_alignment, extent;
|
||||
std::vector<std::string> tokens;
|
||||
|
||||
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
|
||||
|
||||
for (int i = 0; i < int(tokens.size()); ++i) {
|
||||
int x = std::atoi(tokens.at(i).c_str());
|
||||
|
||||
extent.at(i) = x;
|
||||
// round up
|
||||
if (x % alignment) {
|
||||
x += (alignment - (x % alignment));
|
||||
}
|
||||
|
||||
extent_after_alignment.at(i) = x;
|
||||
}
|
||||
|
||||
problem_sizes_after_alignment_host.push_back({extent_after_alignment.m(), extent_after_alignment.n(), extent_after_alignment.k()});
|
||||
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
|
||||
}
|
||||
groups = static_cast<int>(problem_sizes_after_alignment_host.size());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Calculate memory bandwidth statistics
|
||||
template <class ElementA,
|
||||
class ElementB,
|
||||
class ElementC,
|
||||
class ElementD,
|
||||
class ElementBlockScale,
|
||||
class TileShape,
|
||||
int ScaleMsPerTile,
|
||||
int ScaleNsPerTile>
|
||||
auto gbps(double runtime_s) const {
|
||||
double total_read_bytes = 0;
|
||||
double total_write_bytes = 0;
|
||||
|
||||
// Calculate bytes read and written for each problem
|
||||
for (int i = 0; i < groups; ++i) {
|
||||
auto problem = problem_sizes_host.at(i);
|
||||
auto M = cute::get<0>(problem);
|
||||
auto N = cute::get<1>(problem);
|
||||
auto K = cute::get<2>(problem);
|
||||
|
||||
if (M > 0) { // Only count active problems
|
||||
// Matrix A: M*K elements read
|
||||
total_read_bytes += M * K * sizeof(ElementA);
|
||||
|
||||
// Matrix B: K*N elements read
|
||||
total_read_bytes += K * N * sizeof(ElementB);
|
||||
|
||||
// Matrix C: M*N elements read (for beta operation)
|
||||
total_read_bytes += M * N * sizeof(ElementC);
|
||||
|
||||
// Block scales for A and B
|
||||
auto blockscale_shape = cute::shape(cute::get<1>(cute::zipped_divide(cute::make_layout(problem), TileShape{})));
|
||||
auto blockscale_m = cute::get<0>(blockscale_shape);
|
||||
auto blockscale_n = cute::get<1>(blockscale_shape);
|
||||
auto blockscale_k = cute::get<2>(blockscale_shape);
|
||||
auto groupscale_m = blockscale_m * ScaleMsPerTile;
|
||||
auto groupscale_n = blockscale_n * ScaleNsPerTile;
|
||||
|
||||
total_read_bytes += groupscale_m * blockscale_k * sizeof(ElementBlockScale); // A scales
|
||||
total_read_bytes += groupscale_n * blockscale_k * sizeof(ElementBlockScale); // B scales
|
||||
|
||||
// Matrix D: M*N elements written
|
||||
total_write_bytes += M * N * sizeof(ElementD);
|
||||
}
|
||||
}
|
||||
|
||||
return (total_read_bytes + total_write_bytes) / 1.0e9 / runtime_s;
|
||||
}
|
||||
|
||||
double bandwidth_util(double eff_bandwidth) const {
|
||||
int memoryClockRate;
|
||||
int memoryBusWidth;
|
||||
cudaDeviceGetAttribute(&memoryClockRate, cudaDevAttrMemoryClockRate, 0);
|
||||
cudaDeviceGetAttribute(&memoryBusWidth, cudaDevAttrGlobalMemoryBusWidth , 0);
|
||||
double bw = 2.0 * memoryClockRate * (memoryBusWidth / 8) / 1.0e6;
|
||||
return eff_bandwidth / bw * 100.0;
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling\n\n"
|
||||
<< " Hopper FP8 Grouped GEMM using a Warp Specialized kernel with Blockwise Scaling.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --groups=<int> Sets the number of individual GEMM problems for Grouped GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n"
|
||||
<< " --swizzle=<int> CTA Rasterization swizzle\n\n"
|
||||
<< " --benchmark=<str> Executes a benchmark problem size.\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Number of real-valued multiply-adds
|
||||
uint64_t fmas = 0ull;
|
||||
|
||||
for (auto const [m, n, k] : problem_sizes_host) {
|
||||
fmas += static_cast<uint64_t>(m) *
|
||||
static_cast<uint64_t>(n) *
|
||||
static_cast<uint64_t>(k);
|
||||
}
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * uint64_t(fmas);
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
@ -0,0 +1,831 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
Hopper Mixed-input Grouped GEMM example using CUTLASS 3 APIs for NVIDIA Hopper architecture.
|
||||
See 55_hopper_int4_bf16_gemm.cu for more details about W4A16 GEMMs with layout shuffling.
|
||||
|
||||
Limitations:
|
||||
1) Only support row-wise scaling. Zero-points and block-wise scaling is currently not supported.
|
||||
|
||||
To run this example:
|
||||
|
||||
$ ./examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_bf16_grouped_gemm --m=2048 --n=2048 --k=2048 --mode=1 --groups=10
|
||||
|
||||
The above example command makes all 10 groups to be sized at the given m, n, k sizes.
|
||||
Skipping any of the problem dimensions randomizes it across the different groups.
|
||||
Same applies for alpha and beta values that are randomized across the different groups.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <numeric>
|
||||
#include <typeinfo>
|
||||
#include <float.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/mixed_dtype_utils.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
#include "grouped_mixed_dtype_utils.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
|
||||
using MmaType = cutlass::bfloat16_t;
|
||||
using QuantType = cutlass::int4b_t;
|
||||
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = MmaType;
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = QuantType; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// This example manually swaps and transposes, so keep transpose of input layouts
|
||||
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
|
||||
|
||||
// Need to pass a pointer type to make the 3rd dimension of Stride be _0
|
||||
using StrideA = cute::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutA*>>;
|
||||
using StrideB = cute::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutB*>>;
|
||||
|
||||
// Define the CuTe layout for reoredered quantized tensor B
|
||||
// LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory.
|
||||
// It specifies the reordering within a single warp's fragment
|
||||
// using ValueShuffle = Layout<_1>; // no value reordering
|
||||
using ValueShuffle = Layout<Shape<_2,_4>, Stride<_4,_1>>; // order [0,2,4,6,1,3,5,7]
|
||||
int constexpr NumShuffleAtoms = 1;
|
||||
using MmaAtomShape = Layout<Shape<_1,Int<NumShuffleAtoms>>>;
|
||||
using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<MmaType, MmaAtomShape, ValueShuffle>());
|
||||
using LayoutB_Reordered = decltype(cute::tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,Int<1>>, StrideB>{}));
|
||||
|
||||
using ElementZero = cutlass::bfloat16_t;
|
||||
using ElementScale = cutlass::bfloat16_t;
|
||||
using LayoutScale = cutlass::layout::RowMajor;
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// D matrix configuration
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_128,_16,cute::Int<TileShapeK>>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type *, AlignmentC,
|
||||
ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type *, AlignmentD,
|
||||
EpilogueSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloopConvertOnly = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementB, LayoutB_Transpose *, AlignmentB,
|
||||
ElementA, LayoutA_Transpose *, AlignmentA,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelConvertOnly = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloopConvertOnly,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using GemmConvertOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelConvertOnly>;
|
||||
|
||||
using CollectiveMainloopConvertOnlyShuffled = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementB, LayoutB_Reordered *, AlignmentB,
|
||||
ElementA, LayoutA_Transpose *, AlignmentA,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelConvertOnlyShuffled = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloopConvertOnlyShuffled,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using GemmConvertOnlyShuffled = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelConvertOnlyShuffled>;
|
||||
|
||||
// =========================================================== MIXED INPUT WITH SCALES ===========================================================================
|
||||
// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information.
|
||||
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
cute::tuple<ElementB, ElementScale>, LayoutB_Transpose *, AlignmentB,
|
||||
ElementA, LayoutA_Transpose *, AlignmentA,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloopScaleOnly,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
|
||||
|
||||
using CollectiveMainloopScaleOnlyShuffled = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
cute::tuple<ElementB, ElementScale>, LayoutB_Reordered *, AlignmentB,
|
||||
ElementA, LayoutA_Transpose *, AlignmentA,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelScaleOnlyShuffled = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloopScaleOnlyShuffled,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using GemmScaleOnlyShuffled = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnlyShuffled>;
|
||||
|
||||
using StrideC = typename GemmKernelConvertOnly::InternalStrideC;
|
||||
using StrideD = typename GemmKernelConvertOnly::InternalStrideD;
|
||||
using StrideC_ref = cutlass::detail::TagToStrideC_t<LayoutC>;
|
||||
using StrideD_ref = cutlass::detail::TagToStrideC_t<LayoutD>;
|
||||
using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
|
||||
using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
|
||||
|
||||
// Host-side allocations
|
||||
std::vector<int64_t> offset_A;
|
||||
std::vector<int64_t> offset_B;
|
||||
std::vector<int64_t> offset_B_dq;
|
||||
std::vector<int64_t> offset_C;
|
||||
std::vector<int64_t> offset_D;
|
||||
std::vector<int64_t> offset_scale;
|
||||
std::vector<int64_t> offset_zero;
|
||||
|
||||
std::vector<StrideA> stride_A_host;
|
||||
std::vector<StrideB> stride_B_host;
|
||||
std::vector<StrideC> stride_C_host;
|
||||
std::vector<StrideD> stride_D_host;
|
||||
std::vector<StrideC_ref> stride_C_host_ref;
|
||||
std::vector<StrideD_ref> stride_D_host_ref;
|
||||
std::vector<StrideS> stride_S_host;
|
||||
std::vector<StrideS_ref> stride_S_host_ref;
|
||||
|
||||
std::vector<ElementAccumulator> alpha_host;
|
||||
std::vector<ElementAccumulator> beta_host;
|
||||
|
||||
uint64_t seed = 2020;
|
||||
|
||||
// Device-side allocations
|
||||
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
|
||||
|
||||
cutlass::DeviceAllocation<MmaType> block_A;
|
||||
cutlass::DeviceAllocation<QuantType> block_B;
|
||||
cutlass::DeviceAllocation<MmaType> block_B_dq;
|
||||
cutlass::DeviceAllocation<ElementScale> block_scale;
|
||||
cutlass::DeviceAllocation<ElementZero> block_zero;
|
||||
cutlass::DeviceAllocation<ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename GemmConvertOnly::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename GemmConvertOnly::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
|
||||
cutlass::DeviceAllocation<const MmaType *> ptr_A;
|
||||
cutlass::DeviceAllocation<const QuantType *> ptr_B;
|
||||
cutlass::DeviceAllocation<const MmaType *> ptr_B_dq;
|
||||
cutlass::DeviceAllocation<const ElementScale *> ptr_scale;
|
||||
cutlass::DeviceAllocation<const ElementZero *> ptr_zero;
|
||||
cutlass::DeviceAllocation<const ElementC *> ptr_C;
|
||||
cutlass::DeviceAllocation<typename GemmConvertOnly::EpilogueOutputOp::ElementOutput *> ptr_D;
|
||||
|
||||
cutlass::DeviceAllocation<StrideA> stride_A;
|
||||
cutlass::DeviceAllocation<StrideB> stride_B;
|
||||
cutlass::DeviceAllocation<LayoutB_Reordered> layout_B_reordered;
|
||||
cutlass::DeviceAllocation<StrideC> stride_C;
|
||||
cutlass::DeviceAllocation<StrideD> stride_D;
|
||||
cutlass::DeviceAllocation<StrideC_ref> stride_C_ref;
|
||||
cutlass::DeviceAllocation<StrideD_ref> stride_D_ref;
|
||||
cutlass::DeviceAllocation<StrideS_ref> stride_S_ref;
|
||||
cutlass::DeviceAllocation<StrideS> stride_S;
|
||||
|
||||
// Note, this is an array of pointers to alpha and beta scaling values per group
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options : GroupedMixedDtypeOptions<QuantType> {
|
||||
using Base = GroupedMixedDtypeOptions<QuantType>;
|
||||
|
||||
bool shuffle = true;
|
||||
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
cmd.get_cmd_line_argument("shuffle", shuffle);
|
||||
|
||||
this->Base::parse(argc, args);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "69_hopper_int4_bf16_grouped_gemm\n\n"
|
||||
<< " Hopper Mixed Dtype Grouped GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM for all groups\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM for all groups\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM for all groups\n"
|
||||
<< " --groups=<int> Sets the number of individual GEMM problems for Grouped GEMM\n"
|
||||
<< " --mode=<int> The mode to run the gemm. 0 does (A @ B), 1 means A @ (scale * B), 2 means A @ (scale * B + zero-point).\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform\n\n"
|
||||
<< " --warmup=<int> Number of warmup iterations to perform\n\n"
|
||||
<< " --shuffle=<boolean> Enable the offline layout swizzling.\n\n"
|
||||
<< " --benchmark=<str> Executes a benchmark problem size.\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "69_hopper_int4_bf16_grouped_gemm" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=1 --beta=0 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Allocates device-side data
|
||||
void allocate(Options const& options) {
|
||||
int64_t total_elements_A = 0;
|
||||
int64_t total_elements_B = 0;
|
||||
int64_t total_elements_B_dq = 0;
|
||||
int64_t total_elements_C = 0;
|
||||
int64_t total_elements_D = 0;
|
||||
int64_t total_elements_scale = 0;
|
||||
int64_t total_elements_zero = 0;
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
auto M = get<0>(problem);
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.c);
|
||||
|
||||
offset_A.push_back(total_elements_A);
|
||||
offset_B.push_back(total_elements_B * cutlass::sizeof_bits<QuantType>::value / 8);
|
||||
offset_B_dq.push_back(total_elements_B_dq);
|
||||
offset_C.push_back(total_elements_C);
|
||||
offset_D.push_back(total_elements_D);
|
||||
offset_scale.push_back(total_elements_scale);
|
||||
offset_zero.push_back(total_elements_zero);
|
||||
|
||||
int64_t elements_A = M * K;
|
||||
int64_t elements_B = K * N ;
|
||||
int64_t elements_B_dq = K * N;
|
||||
int64_t elements_C = M * N;
|
||||
int64_t elements_D = M * N;
|
||||
int64_t elements_scale = scale_k * N;
|
||||
int64_t elements_zero = scale_k * N;
|
||||
|
||||
total_elements_A += elements_A;
|
||||
total_elements_B += elements_B;
|
||||
total_elements_B_dq += elements_B_dq;
|
||||
total_elements_C += elements_C;
|
||||
total_elements_D += elements_D;
|
||||
total_elements_scale += elements_scale;
|
||||
total_elements_zero += elements_zero;
|
||||
|
||||
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
|
||||
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
|
||||
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {N, M, 1}));
|
||||
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {N, M, 1}));
|
||||
stride_C_host_ref.push_back(cutlass::make_cute_packed_stride(StrideC_ref{}, {M, N, 1}));
|
||||
stride_D_host_ref.push_back(cutlass::make_cute_packed_stride(StrideD_ref{}, {M, N, 1}));
|
||||
stride_S_host_ref.push_back(cutlass::make_cute_packed_stride(StrideS_ref{}, {N, scale_k, 1}));
|
||||
stride_S_host.push_back(cutlass::make_cute_packed_stride(StrideS{}, {N, scale_k, 1}));
|
||||
}
|
||||
|
||||
block_A.reset(total_elements_A);
|
||||
block_B.reset(total_elements_B);
|
||||
block_B_dq.reset(total_elements_B_dq);
|
||||
block_C.reset(total_elements_C);
|
||||
block_D.reset(total_elements_D);
|
||||
block_ref_D.reset(total_elements_D);
|
||||
block_scale.reset(total_elements_scale);
|
||||
block_zero.reset(total_elements_zero);
|
||||
|
||||
block_alpha.reset(options.groups);
|
||||
block_beta.reset(options.groups);
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(Options &options) {
|
||||
|
||||
uint64_t seed = 2020;
|
||||
|
||||
problem_sizes.reset(options.groups);
|
||||
problem_sizes.copy_from_host(options.problem_sizes_host.data());
|
||||
|
||||
//
|
||||
// Assign pointers
|
||||
//
|
||||
|
||||
std::vector<MmaType *> ptr_A_host(options.groups);
|
||||
std::vector<QuantType *> ptr_B_host(options.groups);
|
||||
std::vector<MmaType *> ptr_B_dq_host(options.groups);
|
||||
std::vector<ElementC *> ptr_C_host(options.groups);
|
||||
std::vector<ElementC *> ptr_D_host(options.groups);
|
||||
std::vector<ElementScale *> ptr_scale_host(options.groups);
|
||||
std::vector<ElementZero *> ptr_zero_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
|
||||
ptr_B_dq_host.at(i) = block_B_dq.get() + offset_B_dq.at(i);
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
ptr_scale_host.at(i) = block_scale.get() + offset_scale.at(i);
|
||||
ptr_zero_host.at(i) = block_zero.get() + offset_zero.at(i);
|
||||
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
|
||||
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
|
||||
ptr_alpha_host.at(i) = block_alpha.get() + i;
|
||||
ptr_beta_host.at(i) = block_beta.get() + i;
|
||||
}
|
||||
|
||||
ptr_A.reset(options.groups);
|
||||
ptr_A.copy_from_host(ptr_A_host.data());
|
||||
|
||||
ptr_B.reset(options.groups);
|
||||
ptr_B.copy_from_host(ptr_B_host.data());
|
||||
|
||||
ptr_B_dq.reset(options.groups);
|
||||
ptr_B_dq.copy_from_host(ptr_B_dq_host.data());
|
||||
|
||||
ptr_C.reset(options.groups);
|
||||
ptr_C.copy_from_host(ptr_C_host.data());
|
||||
|
||||
ptr_D.reset(options.groups);
|
||||
ptr_D.copy_from_host(ptr_D_host.data());
|
||||
|
||||
ptr_scale.reset(options.groups);
|
||||
ptr_scale.copy_from_host(ptr_scale_host.data());
|
||||
|
||||
ptr_zero.reset(options.groups);
|
||||
ptr_zero.copy_from_host(ptr_zero_host.data());
|
||||
|
||||
stride_A.reset(options.groups);
|
||||
stride_A.copy_from_host(stride_A_host.data());
|
||||
|
||||
stride_B.reset(options.groups);
|
||||
stride_B.copy_from_host(stride_B_host.data());
|
||||
|
||||
stride_C.reset(options.groups);
|
||||
stride_C.copy_from_host(stride_C_host.data());
|
||||
|
||||
stride_D.reset(options.groups);
|
||||
stride_D.copy_from_host(stride_D_host.data());
|
||||
|
||||
stride_C_ref.reset(options.groups);
|
||||
stride_C_ref.copy_from_host(stride_C_host_ref.data());
|
||||
|
||||
stride_D_ref.reset(options.groups);
|
||||
stride_D_ref.copy_from_host(stride_D_host_ref.data());
|
||||
|
||||
stride_S_ref.reset(options.groups);
|
||||
stride_S_ref.copy_from_host(stride_S_host_ref.data());
|
||||
|
||||
stride_S.reset(options.groups);
|
||||
stride_S.copy_from_host(stride_S_host.data());
|
||||
|
||||
alpha_device.reset(options.groups);
|
||||
alpha_device.copy_from_host(ptr_alpha_host.data());
|
||||
beta_device.reset(options.groups);
|
||||
beta_device.copy_from_host(ptr_beta_host.data());
|
||||
|
||||
initialize_tensor(block_A, seed + 2023);
|
||||
initialize_tensor(block_B, seed + 2022);
|
||||
initialize_tensor(block_C, seed + 2021);
|
||||
initialize_scale(block_scale, options);
|
||||
initialize_zero(block_zero, options);
|
||||
block_alpha.copy_from_host(alpha_host.data());
|
||||
block_beta.copy_from_host(beta_host.data());
|
||||
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.c);
|
||||
auto shape_B = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), cute::get<2>(options.problem_sizes_host[i]), Int<1>{});
|
||||
auto shape_scale = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), scale_k, Int<1>{});
|
||||
auto layout_B = make_layout(shape_B, stride_B_host.at(i));
|
||||
auto layout_scale = make_layout(shape_scale, stride_S_host_ref.at(i));
|
||||
cudaStream_t stream = cudaStreamDefault;
|
||||
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale, options.c, stream);
|
||||
}
|
||||
|
||||
problem_sizes.reset(options.groups);
|
||||
|
||||
if (options.shuffle) {
|
||||
std::vector<LayoutB_Reordered> layout_B_reordered_host(options.groups);
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
auto shape_B = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), cute::get<2>(options.problem_sizes_host[i]), Int<1>{});
|
||||
auto layout_B = make_layout(shape_B, stride_B_host.at(i));
|
||||
// Repeat the reorder layout atom to tile the whole tensor shape
|
||||
layout_B_reordered_host[i] = tile_to_shape(LayoutAtomQuant{}, shape_B);
|
||||
cutlass::reorder_tensor(block_B.get() + offset_B.at(i), layout_B, layout_B_reordered_host[i]);
|
||||
if (i == 0) {
|
||||
print("Quantized tensor layout: ");
|
||||
print(layout_B_reordered_host[0]);
|
||||
print("\n");
|
||||
}
|
||||
}
|
||||
layout_B_reordered.reset(options.groups);
|
||||
layout_B_reordered.copy_from_host(layout_B_reordered_host.data());
|
||||
}
|
||||
|
||||
// Reverse MN -> NM for SwapAB
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
auto [M, N, K] = options.problem_sizes_host[i];
|
||||
options.problem_sizes_host[i] = make_tuple(N, M, K);
|
||||
}
|
||||
problem_sizes.copy_from_host(options.problem_sizes_host.data());
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
template <typename Gemm>
|
||||
typename Gemm::Arguments args_from_options(Options const& options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
using Args = typename Gemm::Arguments;
|
||||
auto&& dB = [&]() {
|
||||
// NOTE: add GemmScaleWithZeroPointShuffled
|
||||
if constexpr (cute::is_same_v<Gemm, GemmConvertOnlyShuffled> ||
|
||||
cute::is_same_v<Gemm, GemmScaleOnlyShuffled>) {
|
||||
// offline swizzling is enabled.
|
||||
return layout_B_reordered.get();
|
||||
}
|
||||
else {
|
||||
return stride_B.get();
|
||||
}
|
||||
}();
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
Args arguments;
|
||||
decltype(arguments.epilogue.thread) fusion_args;
|
||||
|
||||
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
|
||||
// If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.beta = options.beta;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = nullptr;
|
||||
fusion_args.beta_ptr_array = nullptr;
|
||||
// Single alpha and beta for all groups
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
|
||||
}
|
||||
else {
|
||||
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
|
||||
fusion_args.alpha = 0;
|
||||
fusion_args.beta = 0;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = alpha_device.get();
|
||||
fusion_args.beta_ptr_array = beta_device.get();
|
||||
// One alpha and beta per each group
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
|
||||
}
|
||||
|
||||
if constexpr (Gemm::CollectiveMainloop::KernelConversionMode == Gemm::CollectiveMainloop::ConversionMode::DirectConvert) {
|
||||
arguments = Args {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), nullptr},
|
||||
{ptr_B.get(), dB, ptr_A.get(), stride_A.get()},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
else if constexpr (Gemm::CollectiveMainloop::KernelConversionMode == Gemm::CollectiveMainloop::ConversionMode::ConvertAndScale) {
|
||||
arguments = Args {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), nullptr},
|
||||
{ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.c},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
else {
|
||||
std::cerr << "Invalid mode " << options.mode << ". Must be 0, 1 or 2." << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(Options const& options) {
|
||||
bool passed = true;
|
||||
|
||||
constexpr bool IsFP8Input = cute::is_same_v<MmaType, cutlass::float_e4m3_t> || cute::is_same_v<MmaType, cutlass::float_e5m2_t>;
|
||||
using FP8Sched = cute::conditional_t<size<0>(TileShape{}) == 64, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>;
|
||||
using ScheduleRef = cute::conditional_t<IsFP8Input, FP8Sched, cutlass::gemm::collective::KernelScheduleAuto>;
|
||||
|
||||
|
||||
using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
MmaType, LayoutA, AlignmentA,
|
||||
MmaType, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
ScheduleRef
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloopRef,
|
||||
CollectiveEpilogueRef
|
||||
>;
|
||||
|
||||
using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;
|
||||
using StrideA_verif = typename GemmRef::GemmKernel::StrideA;
|
||||
using StrideB_verif = typename GemmRef::GemmKernel::StrideB;
|
||||
using StrideC_verif = typename GemmRef::GemmKernel::StrideC;
|
||||
using StrideD_verif = typename GemmRef::GemmKernel::StrideD;
|
||||
|
||||
const ElementD epsilon(1e-2f);
|
||||
const ElementD non_zero_floor(1e-4f);
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
// we don't swap and transpose in the verify so revert the problem shape.
|
||||
auto N = get<0>(problem);
|
||||
auto M = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
if (M == 0) {
|
||||
continue;
|
||||
}
|
||||
else {
|
||||
StrideA_verif stride_A_verif;
|
||||
StrideB_verif stride_B_verif;
|
||||
|
||||
stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1));
|
||||
stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1));
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
typename GemmRef::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{M, N, K},
|
||||
{block_A.get() + offset_A.at(i), stride_A_verif, block_B_dq.get() + offset_B_dq.at(i), stride_B_verif},
|
||||
{{alpha_host.at(i), beta_host.at(i)}, block_C.get() + offset_C.at(i), stride_C_host_ref.at(i), block_ref_D.get() + offset_D.at(i), stride_D_host_ref.at(i)}
|
||||
};
|
||||
|
||||
// Run the gemm where the scaling is performed outside of the kernel.
|
||||
GemmRef gemm_ref;
|
||||
size_t workspace_size = GemmRef::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
CUTLASS_CHECK(gemm_ref.can_implement(arguments));
|
||||
CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm_ref.run());
|
||||
|
||||
// Wait for kernel to finish
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor);
|
||||
std::cout << "Group " << i << ": " << options.problem_sizes_host[i] << ", alpha: " << alpha_host[i] << ", beta: " << beta_host[i] << " Status: " << passed << std::endl;
|
||||
}
|
||||
}
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options<Gemm>(options, host_problem_shapes_available);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
std::cout << "We passed all checks\n";
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
MixedDtypeResult result;
|
||||
result.passed = verify(options);
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
grouped_mixed_dtype_profiling(gemm, options, result, alpha_host, beta_host);
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.3 Toolkit to run this example
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) {
|
||||
std::cerr << "This example requires CUDA 12.3 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
if (options.mode == MixedDtypeGemmMode::ConvertOnly) {
|
||||
std::cout << "Running in no scale mode." << std::endl;
|
||||
if (options.shuffle) {
|
||||
std::cout << "Offline shuffle enabled." << std::endl;
|
||||
run<GemmConvertOnlyShuffled>(options, false);
|
||||
} else {
|
||||
std::cout << "Offline shuffle disabled." << std::endl;
|
||||
run<GemmConvertOnly>(options, false);
|
||||
}
|
||||
}
|
||||
else if (options.mode == MixedDtypeGemmMode::ScaleOnly) {
|
||||
std::cout << "Running in per-column scale mode." << std::endl;
|
||||
if (options.shuffle) {
|
||||
std::cout << "Offline shuffle enabled." << std::endl;
|
||||
run<GemmScaleOnlyShuffled>(options, false);
|
||||
} else {
|
||||
std::cout << "Offline shuffle disabled." << std::endl;
|
||||
run<GemmScaleOnly>(options, false);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,766 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
Hopper Mixed-input Grouped GEMM example using CUTLASS 3 APIs for NVIDIA Hopper architecture.
|
||||
See 55_hopper_int4_fp8_gemm.cu for more details about W4A8 GEMMs with lookup table.
|
||||
|
||||
Limitations:
|
||||
1) Only support row-wise scaling. Zero-points and block-wise scaling is currently not supported.
|
||||
|
||||
To run this example:
|
||||
|
||||
$ ./examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_fp8_grouped_gemm --m=2048 --n=2048 --k=2048 --mode=1 --groups=10
|
||||
|
||||
The above example command makes all 10 groups to be sized at the given m, n, k sizes.
|
||||
Skipping any of the problem dimensions randomizes it across the different groups.
|
||||
Same applies for alpha and beta values that are randomized across the different groups.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <numeric>
|
||||
#include <typeinfo>
|
||||
#include <float.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/mixed_dtype_utils.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
#include "grouped_mixed_dtype_utils.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
|
||||
using MmaType = cutlass::float_e4m3_t;
|
||||
using QuantType = cutlass::int4b_t;
|
||||
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = MmaType;
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = QuantType; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// This example manually swaps and transposes, so keep transpose of input layouts
|
||||
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
|
||||
|
||||
// Need to pass a pointer type to make the 3rd dimension of Stride be _0
|
||||
using StrideA = cute::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutA*>>;
|
||||
using StrideB = cute::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutB*>>;
|
||||
|
||||
// Define the CuTe layout for reoredered quantized tensor B
|
||||
// LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory.
|
||||
// It specifies the reordering within a single warp's fragment
|
||||
using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<MmaType>());
|
||||
using LayoutB_Reordered = decltype(cute::tile_to_shape(LayoutAtomQuant{}, Layout<Shape<int,int,Int<1>>, StrideB>{}));
|
||||
|
||||
using ElementZero = cutlass::float_e4m3_t;
|
||||
using ElementScale = cutlass::float_e4m3_t;
|
||||
using LayoutScale = cutlass::layout::RowMajor;
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// D matrix configuration
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_128,_16,cute::Int<TileShapeK>>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type *, AlignmentC,
|
||||
ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type *, AlignmentD,
|
||||
EpilogueSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
// =========================================================== MIXED INPUT WITH SCALES ===========================================================================
|
||||
// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information.
|
||||
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>, LayoutB_Transpose *, AlignmentB,
|
||||
ElementA, LayoutA_Transpose *, AlignmentA,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloopScaleOnly,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using CollectiveMainloopShuffled = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>, LayoutB_Reordered *, AlignmentB,
|
||||
ElementA, LayoutA_Transpose *, AlignmentA,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloopShuffled,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
|
||||
using GemmShuffled = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelShuffled>;
|
||||
|
||||
using StrideC = typename GemmKernelScaleOnly::InternalStrideC;
|
||||
using StrideD = typename GemmKernelScaleOnly::InternalStrideD;
|
||||
|
||||
using StrideC_ref = cutlass::detail::TagToStrideC_t<LayoutC>;
|
||||
using StrideD_ref = cutlass::detail::TagToStrideC_t<LayoutD>;
|
||||
using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
|
||||
using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
|
||||
|
||||
// Host-side allocations
|
||||
std::vector<int64_t> offset_A;
|
||||
std::vector<int64_t> offset_B;
|
||||
std::vector<int64_t> offset_B_dq;
|
||||
std::vector<int64_t> offset_C;
|
||||
std::vector<int64_t> offset_D;
|
||||
std::vector<int64_t> offset_scale;
|
||||
std::vector<int64_t> offset_zero;
|
||||
|
||||
std::vector<StrideA> stride_A_host;
|
||||
std::vector<StrideB> stride_B_host;
|
||||
std::vector<StrideC> stride_C_host;
|
||||
std::vector<StrideD> stride_D_host;
|
||||
std::vector<StrideC_ref> stride_C_host_ref;
|
||||
std::vector<StrideD_ref> stride_D_host_ref;
|
||||
std::vector<StrideS> stride_S_host;
|
||||
std::vector<StrideS_ref> stride_S_host_ref;
|
||||
|
||||
std::vector<ElementAccumulator> alpha_host;
|
||||
std::vector<ElementAccumulator> beta_host;
|
||||
|
||||
uint64_t seed = 2020;
|
||||
|
||||
// Device-side allocations
|
||||
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
|
||||
|
||||
cutlass::DeviceAllocation<MmaType> block_A;
|
||||
cutlass::DeviceAllocation<QuantType> block_B;
|
||||
cutlass::DeviceAllocation<ElementB> block_B_modified;
|
||||
cutlass::DeviceAllocation<MmaType> block_B_dq;
|
||||
cutlass::DeviceAllocation<ElementScale> block_scale;
|
||||
cutlass::DeviceAllocation<cutlass::Array<ElementScale, 8>> block_scale_packed;
|
||||
cutlass::DeviceAllocation<ElementZero> block_zero;
|
||||
cutlass::DeviceAllocation<ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
|
||||
cutlass::DeviceAllocation<const MmaType *> ptr_A;
|
||||
cutlass::DeviceAllocation<const QuantType *> ptr_B;
|
||||
cutlass::DeviceAllocation<const MmaType *> ptr_B_dq;
|
||||
cutlass::DeviceAllocation<const cutlass::Array<ElementScale, 8> *> ptr_scale_packed;
|
||||
cutlass::DeviceAllocation<const ElementZero *> ptr_zero;
|
||||
cutlass::DeviceAllocation<const ElementC *> ptr_C;
|
||||
cutlass::DeviceAllocation<typename GemmScaleOnly::EpilogueOutputOp::ElementOutput *> ptr_D;
|
||||
|
||||
cutlass::DeviceAllocation<StrideA> stride_A;
|
||||
cutlass::DeviceAllocation<StrideB> stride_B;
|
||||
cutlass::DeviceAllocation<LayoutB_Reordered> layout_B_reordered;
|
||||
cutlass::DeviceAllocation<StrideC> stride_C;
|
||||
cutlass::DeviceAllocation<StrideD> stride_D;
|
||||
cutlass::DeviceAllocation<StrideC_ref> stride_C_ref;
|
||||
cutlass::DeviceAllocation<StrideD_ref> stride_D_ref;
|
||||
cutlass::DeviceAllocation<StrideS_ref> stride_S_ref;
|
||||
cutlass::DeviceAllocation<StrideS> stride_S;
|
||||
|
||||
// Note, this is an array of pointers to alpha and beta scaling values per group
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options : GroupedMixedDtypeOptions<QuantType> {
|
||||
using Base = GroupedMixedDtypeOptions<QuantType>;
|
||||
|
||||
bool shuffle = true;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
cmd.get_cmd_line_argument("shuffle", shuffle);
|
||||
|
||||
this->Base::parse(argc, args);
|
||||
|
||||
mode = 1; // override the mode value to always be scale only mode
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "69_hopper_int4_fp8_grouped_gemm\n\n"
|
||||
<< " Hopper Mixed Dtype Grouped GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM for all groups\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM for all groups\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM for all groups\n"
|
||||
<< " --groups=<int> Sets the number of individual GEMM problems for Grouped GEMM\n"
|
||||
<< " --c=<int> The size of each chunk for the scales and zeros. To broadcast a vector of scales or zeros, set the group size to K.\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform\n\n"
|
||||
<< " --warmup=<int> Number of warmup iterations to perform\n\n"
|
||||
<< " --shuffle=<boolean> Enable the offline layout swizzling.\n\n"
|
||||
<< " --benchmark=<str> Executes a benchmark problem size.\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "69_hopper_int4_fp8_grouped_gemm" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=1 --beta=0 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// In the mainloop, PRMT selects 1 byte from only 8 bytes so the sign bit is handled in an extra PRMT.
|
||||
// Here the encodings of positive values and negative values are unified (except for the sign bit).
|
||||
// For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111).
|
||||
|
||||
/// Allocates device-side data
|
||||
void allocate(Options const& options) {
|
||||
int64_t total_elements_A = 0;
|
||||
int64_t total_elements_B = 0;
|
||||
int64_t total_elements_B_dq = 0;
|
||||
int64_t total_elements_C = 0;
|
||||
int64_t total_elements_D = 0;
|
||||
int64_t total_elements_scale = 0;
|
||||
int64_t total_elements_zero = 0;
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
auto M = get<0>(problem);
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.c);
|
||||
|
||||
offset_A.push_back(total_elements_A);
|
||||
offset_B.push_back(total_elements_B * cutlass::sizeof_bits<QuantType>::value / 8);
|
||||
offset_B_dq.push_back(total_elements_B_dq);
|
||||
offset_C.push_back(total_elements_C);
|
||||
offset_D.push_back(total_elements_D);
|
||||
offset_scale.push_back(total_elements_scale);
|
||||
offset_zero.push_back(total_elements_zero);
|
||||
|
||||
int64_t elements_A = M * K;
|
||||
int64_t elements_B = K * N ;
|
||||
int64_t elements_B_dq = K * N;
|
||||
int64_t elements_C = M * N;
|
||||
int64_t elements_D = M * N;
|
||||
int64_t elements_scale = scale_k * N;
|
||||
int64_t elements_zero = scale_k * N;
|
||||
|
||||
total_elements_A += elements_A;
|
||||
total_elements_B += elements_B;
|
||||
total_elements_B_dq += elements_B_dq;
|
||||
total_elements_C += elements_C;
|
||||
total_elements_D += elements_D;
|
||||
total_elements_scale += elements_scale;
|
||||
total_elements_zero += elements_zero;
|
||||
|
||||
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
|
||||
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
|
||||
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {N, M, 1}));
|
||||
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {N, M, 1}));
|
||||
stride_C_host_ref.push_back(cutlass::make_cute_packed_stride(StrideC_ref{}, {M, N, 1}));
|
||||
stride_D_host_ref.push_back(cutlass::make_cute_packed_stride(StrideD_ref{}, {M, N, 1}));
|
||||
stride_S_host_ref.push_back(cutlass::make_cute_packed_stride(StrideS_ref{}, {N, scale_k, 1}));
|
||||
stride_S_host.push_back(cutlass::make_cute_packed_stride(StrideS{}, {N, scale_k, 1}));
|
||||
}
|
||||
|
||||
block_A.reset(total_elements_A);
|
||||
block_B.reset(total_elements_B);
|
||||
block_B_modified.reset(total_elements_B);
|
||||
block_B_dq.reset(total_elements_B_dq);
|
||||
block_C.reset(total_elements_C);
|
||||
block_D.reset(total_elements_D);
|
||||
block_ref_D.reset(total_elements_D);
|
||||
block_scale.reset(total_elements_scale);
|
||||
block_scale_packed.reset(total_elements_scale);
|
||||
block_zero.reset(total_elements_zero);
|
||||
|
||||
block_alpha.reset(options.groups);
|
||||
block_beta.reset(options.groups);
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(Options& options) {
|
||||
|
||||
uint64_t seed = 2020;
|
||||
|
||||
problem_sizes.reset(options.groups);
|
||||
problem_sizes.copy_from_host(options.problem_sizes_host.data());
|
||||
|
||||
//
|
||||
// Assign pointers
|
||||
//
|
||||
|
||||
std::vector<MmaType *> ptr_A_host(options.groups);
|
||||
std::vector<QuantType *> ptr_B_host(options.groups);
|
||||
std::vector<MmaType *> ptr_B_dq_host(options.groups);
|
||||
std::vector<ElementC *> ptr_C_host(options.groups);
|
||||
std::vector<ElementC *> ptr_D_host(options.groups);
|
||||
std::vector<cutlass::Array<ElementScale, 8> *> ptr_scale_packed_host(options.groups);
|
||||
std::vector<ElementZero *> ptr_zero_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
ptr_B_host.at(i) = block_B_modified.get() + offset_B.at(i);
|
||||
ptr_B_dq_host.at(i) = block_B_dq.get() + offset_B_dq.at(i);
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
ptr_scale_packed_host.at(i) = block_scale_packed.get() + offset_scale.at(i);
|
||||
ptr_zero_host.at(i) = block_zero.get() + offset_zero.at(i);
|
||||
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
|
||||
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
|
||||
ptr_alpha_host.at(i) = block_alpha.get() + i;
|
||||
ptr_beta_host.at(i) = block_beta.get() + i;
|
||||
}
|
||||
|
||||
ptr_A.reset(options.groups);
|
||||
ptr_A.copy_from_host(ptr_A_host.data());
|
||||
|
||||
ptr_B.reset(options.groups);
|
||||
ptr_B.copy_from_host(ptr_B_host.data());
|
||||
|
||||
ptr_B_dq.reset(options.groups);
|
||||
ptr_B_dq.copy_from_host(ptr_B_dq_host.data());
|
||||
|
||||
ptr_C.reset(options.groups);
|
||||
ptr_C.copy_from_host(ptr_C_host.data());
|
||||
|
||||
ptr_D.reset(options.groups);
|
||||
ptr_D.copy_from_host(ptr_D_host.data());
|
||||
|
||||
ptr_scale_packed.reset(options.groups);
|
||||
ptr_scale_packed.copy_from_host(ptr_scale_packed_host.data());
|
||||
|
||||
ptr_zero.reset(options.groups);
|
||||
ptr_zero.copy_from_host(ptr_zero_host.data());
|
||||
|
||||
stride_A.reset(options.groups);
|
||||
stride_A.copy_from_host(stride_A_host.data());
|
||||
|
||||
stride_B.reset(options.groups);
|
||||
stride_B.copy_from_host(stride_B_host.data());
|
||||
|
||||
stride_C.reset(options.groups);
|
||||
stride_C.copy_from_host(stride_C_host.data());
|
||||
|
||||
stride_D.reset(options.groups);
|
||||
stride_D.copy_from_host(stride_D_host.data());
|
||||
|
||||
stride_C_ref.reset(options.groups);
|
||||
stride_C_ref.copy_from_host(stride_C_host_ref.data());
|
||||
|
||||
stride_D_ref.reset(options.groups);
|
||||
stride_D_ref.copy_from_host(stride_D_host_ref.data());
|
||||
|
||||
stride_S_ref.reset(options.groups);
|
||||
stride_S_ref.copy_from_host(stride_S_host_ref.data());
|
||||
|
||||
stride_S.reset(options.groups);
|
||||
stride_S.copy_from_host(stride_S_host.data());
|
||||
|
||||
alpha_device.reset(options.groups);
|
||||
alpha_device.copy_from_host(ptr_alpha_host.data());
|
||||
beta_device.reset(options.groups);
|
||||
beta_device.copy_from_host(ptr_beta_host.data());
|
||||
|
||||
initialize_tensor(block_A, seed + 2023);
|
||||
initialize_tensor(block_B, seed + 2022);
|
||||
cutlass::unified_encode_int4b(block_B.get(), block_B_modified.get(), block_B.size());
|
||||
initialize_tensor(block_C, seed + 2021);
|
||||
initialize_scale(block_scale, options);
|
||||
cutlass::pack_scale_fp8(block_scale.get(), block_scale_packed.get(), block_scale.size());
|
||||
initialize_zero(block_zero, options);
|
||||
block_alpha.copy_from_host(alpha_host.data());
|
||||
block_beta.copy_from_host(beta_host.data());
|
||||
|
||||
problem_sizes.reset(options.groups);
|
||||
|
||||
if (options.shuffle) {
|
||||
std::vector<LayoutB_Reordered> layout_B_reordered_host(options.groups);
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
auto shape_B = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), cute::get<2>(options.problem_sizes_host[i]), Int<1>{});
|
||||
auto layout_B = make_layout(shape_B, stride_B_host.at(i));
|
||||
// Repeat the reorder layout atom to tile the whole tensor shape
|
||||
layout_B_reordered_host[i] = tile_to_shape(LayoutAtomQuant{}, shape_B);
|
||||
cutlass::reorder_tensor(block_B_modified.get() + offset_B.at(i), layout_B, layout_B_reordered_host[i]);
|
||||
if (i == 0) {
|
||||
print("Quantized tensor layout: ");
|
||||
print(layout_B_reordered_host[0]);
|
||||
print("\n");
|
||||
}
|
||||
}
|
||||
layout_B_reordered.reset(options.groups);
|
||||
layout_B_reordered.copy_from_host(layout_B_reordered_host.data());
|
||||
}
|
||||
|
||||
// Reverse MN -> NM for SwapAB
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
auto [M, N, K] = options.problem_sizes_host[i];
|
||||
options.problem_sizes_host[i] = make_tuple(N, M, K);
|
||||
}
|
||||
problem_sizes.copy_from_host(options.problem_sizes_host.data());
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
template <typename Gemm>
|
||||
typename Gemm::Arguments args_from_options(Options const& options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
using Args = typename Gemm::Arguments;
|
||||
auto&& dB = [&]() {
|
||||
if constexpr (cute::is_same_v<Gemm, GemmShuffled>) { // offline swizzling is enabled.
|
||||
return layout_B_reordered.get();
|
||||
}
|
||||
else {
|
||||
return stride_B.get();
|
||||
}
|
||||
}();
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
Args arguments;
|
||||
decltype(arguments.epilogue.thread) fusion_args;
|
||||
|
||||
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
|
||||
// If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.beta = options.beta;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = nullptr;
|
||||
fusion_args.beta_ptr_array = nullptr;
|
||||
// Single alpha and beta for all groups
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
|
||||
}
|
||||
else {
|
||||
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
|
||||
fusion_args.alpha = 0;
|
||||
fusion_args.beta = 0;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = alpha_device.get();
|
||||
fusion_args.beta_ptr_array = beta_device.get();
|
||||
// One alpha and beta per each group
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
|
||||
}
|
||||
arguments = Args {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), nullptr},
|
||||
{ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale_packed.get(), stride_S.get(), options.c},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info
|
||||
};
|
||||
return arguments;
|
||||
}
|
||||
|
||||
|
||||
bool verify(Options const& options) {
|
||||
bool passed = true;
|
||||
|
||||
constexpr bool IsFP8Input = cute::is_same_v<MmaType, cutlass::float_e4m3_t> || cute::is_same_v<MmaType, cutlass::float_e5m2_t>;
|
||||
using FP8Sched = cute::conditional_t<size<0>(TileShape{}) == 64, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>;
|
||||
using ScheduleRef = cute::conditional_t<IsFP8Input, FP8Sched, cutlass::gemm::collective::KernelScheduleAuto>;
|
||||
|
||||
using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
MmaType, LayoutA, AlignmentA,
|
||||
MmaType, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
ScheduleRef
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloopRef,
|
||||
CollectiveEpilogueRef
|
||||
>;
|
||||
|
||||
using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;
|
||||
using StrideA_verif = typename GemmRef::GemmKernel::StrideA;
|
||||
using StrideB_verif = typename GemmRef::GemmKernel::StrideB;
|
||||
using StrideC_verif = typename GemmRef::GemmKernel::StrideC;
|
||||
using StrideD_verif = typename GemmRef::GemmKernel::StrideD;
|
||||
|
||||
const ElementD epsilon(1e-2f);
|
||||
const ElementD non_zero_floor(1e-4f);
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
// we don't swap and transpose in the verify so revert the problem shape.
|
||||
auto N = get<0>(problem);
|
||||
auto M = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
if (M == 0) {
|
||||
continue;
|
||||
}
|
||||
else {
|
||||
StrideA_verif stride_A_verif;
|
||||
StrideB_verif stride_B_verif;
|
||||
|
||||
stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1));
|
||||
stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1));
|
||||
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.c);
|
||||
auto layout_B = make_layout(cute::make_shape(N, K, Int<1>{}), stride_B_host.at(i));
|
||||
auto layout_scale_zero = make_layout(cute::make_shape(N, scale_k, Int<1>{}), stride_S_host_ref.at(i));
|
||||
cudaStream_t stream = cudaStreamDefault;
|
||||
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.c, stream);
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
typename GemmRef::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{M, N, K},
|
||||
{block_A.get() + offset_A.at(i), stride_A_verif, block_B_dq.get() + offset_B_dq.at(i), stride_B_verif},
|
||||
{{alpha_host.at(i), beta_host.at(i)}, block_C.get() + offset_C.at(i), stride_C_host_ref.at(i), block_ref_D.get() + offset_D.at(i), stride_D_host_ref.at(i)}
|
||||
};
|
||||
|
||||
// Run the gemm where the scaling is performed outside of the kernel.
|
||||
GemmRef gemm_ref;
|
||||
size_t workspace_size = GemmRef::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
CUTLASS_CHECK(gemm_ref.can_implement(arguments));
|
||||
CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm_ref.run());
|
||||
|
||||
// Wait for kernel to finish
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor);
|
||||
std::cout << "Group " << i << ": " << options.problem_sizes_host[i] << ", alpha: " << alpha_host[i] << ", beta: " << beta_host[i] << " Status: " << passed << std::endl;
|
||||
}
|
||||
}
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options<Gemm>(options, host_problem_shapes_available);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
std::cout << "We passed all checks\n";
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
MixedDtypeResult result;
|
||||
result.passed = verify(options);
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
grouped_mixed_dtype_profiling(gemm, options, result, alpha_host, beta_host);
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.3 Toolkit to run this example
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) {
|
||||
std::cerr << "This example requires CUDA 12.3 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
std::cout << "Running in per-column scale mode." << std::endl;
|
||||
if (options.shuffle) {
|
||||
std::cout << "Offline shuffle enabled." << std::endl;
|
||||
run<GemmShuffled>(options, false);
|
||||
} else {
|
||||
std::cout << "Offline shuffle disabled." << std::endl;
|
||||
run<GemmScaleOnly>(options, false);
|
||||
}
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,691 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
Hopper Mixed-input Grouped GEMM example using CUTLASS 3 APIs for NVIDIA Hopper architecture.
|
||||
See 55_hopper_mixed_dtype_gemm.cu for more details about Mixed-input GEMMs.
|
||||
|
||||
Limitations:
|
||||
1) Only support row-wise scaling. Zero-points and block-wise scaling is currently not supported.
|
||||
|
||||
To run this example:
|
||||
|
||||
$ ./examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_mixed_dtype_grouped_gemm --m=2048 --n=2048 --k=2048 --mode=1 --groups=10
|
||||
|
||||
The above example command makes all 10 groups to be sized at the given m, n, k sizes.
|
||||
Skipping any of the problem dimensions randomizes it across the different groups.
|
||||
Same applies for alpha and beta values that are randomized across the different groups.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <numeric>
|
||||
#include <typeinfo>
|
||||
#include <float.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/mixed_dtype_utils.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
#include "grouped_mixed_dtype_utils.hpp"
|
||||
|
||||
using namespace cute;
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
|
||||
using MmaType = cutlass::bfloat16_t;
|
||||
using QuantType = cutlass::float_e5m2_t;
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
constexpr int TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = MmaType;
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = QuantType; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// This example manually swaps and transposes, so keep transpose of input layouts
|
||||
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
|
||||
|
||||
using ElementZero = cutlass::bfloat16_t;
|
||||
using ElementScale = cutlass::bfloat16_t;
|
||||
using LayoutScale = cutlass::layout::RowMajor;
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// D matrix configuration
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_128,_16,cute::Int<TileShapeK>>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type *, AlignmentC,
|
||||
ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type *, AlignmentD,
|
||||
EpilogueSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloopConvertOnly = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementB, LayoutB_Transpose *, AlignmentB,
|
||||
ElementA, LayoutA_Transpose *, AlignmentA,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelConvertOnly = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloopConvertOnly,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using GemmConvertOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelConvertOnly>;
|
||||
|
||||
// =========================================================== MIXED INPUT WITH SCALES ===========================================================================
|
||||
// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information.
|
||||
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
cute::tuple<ElementB, ElementScale>, LayoutB_Transpose *, AlignmentB,
|
||||
ElementA, LayoutA_Transpose *, AlignmentA,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloopScaleOnly,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
|
||||
|
||||
using StrideA = typename GemmConvertOnly::GemmKernel::InternalStrideA;
|
||||
using StrideB = typename GemmConvertOnly::GemmKernel::InternalStrideB;
|
||||
using StrideC = typename GemmConvertOnly::GemmKernel::InternalStrideC;
|
||||
using StrideD = typename GemmConvertOnly::GemmKernel::InternalStrideD;
|
||||
using StrideC_ref = cutlass::detail::TagToStrideC_t<LayoutC>;
|
||||
using StrideD_ref = cutlass::detail::TagToStrideC_t<LayoutD>;
|
||||
using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
|
||||
using StrideS_ref = cutlass::detail::TagToStrideB_t<LayoutScale>;
|
||||
|
||||
// Host-side allocations
|
||||
std::vector<int64_t> offset_A;
|
||||
std::vector<int64_t> offset_B;
|
||||
std::vector<int64_t> offset_B_dq;
|
||||
std::vector<int64_t> offset_C;
|
||||
std::vector<int64_t> offset_D;
|
||||
std::vector<int64_t> offset_scale;
|
||||
std::vector<int64_t> offset_zero;
|
||||
|
||||
std::vector<StrideA> stride_A_host;
|
||||
std::vector<StrideB> stride_B_host;
|
||||
std::vector<StrideC> stride_C_host;
|
||||
std::vector<StrideD> stride_D_host;
|
||||
std::vector<StrideC_ref> stride_C_host_ref;
|
||||
std::vector<StrideD_ref> stride_D_host_ref;
|
||||
std::vector<StrideS> stride_S_host;
|
||||
std::vector<StrideS_ref> stride_S_host_ref;
|
||||
|
||||
std::vector<ElementAccumulator> alpha_host;
|
||||
std::vector<ElementAccumulator> beta_host;
|
||||
|
||||
uint64_t seed = 2020;
|
||||
|
||||
// Device-side allocations
|
||||
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
|
||||
|
||||
cutlass::DeviceAllocation<MmaType> block_A;
|
||||
cutlass::DeviceAllocation<QuantType> block_B;
|
||||
cutlass::DeviceAllocation<MmaType> block_B_dq;
|
||||
cutlass::DeviceAllocation<ElementScale> block_scale;
|
||||
cutlass::DeviceAllocation<ElementZero> block_zero;
|
||||
cutlass::DeviceAllocation<ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename GemmConvertOnly::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename GemmConvertOnly::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
|
||||
cutlass::DeviceAllocation<const MmaType *> ptr_A;
|
||||
cutlass::DeviceAllocation<const QuantType *> ptr_B;
|
||||
cutlass::DeviceAllocation<const MmaType *> ptr_B_dq;
|
||||
cutlass::DeviceAllocation<const ElementScale *> ptr_scale;
|
||||
cutlass::DeviceAllocation<const ElementZero *> ptr_zero;
|
||||
cutlass::DeviceAllocation<const ElementC *> ptr_C;
|
||||
cutlass::DeviceAllocation<typename GemmConvertOnly::EpilogueOutputOp::ElementOutput *> ptr_D;
|
||||
|
||||
cutlass::DeviceAllocation<StrideA> stride_A;
|
||||
cutlass::DeviceAllocation<StrideB> stride_B;
|
||||
cutlass::DeviceAllocation<StrideC> stride_C;
|
||||
cutlass::DeviceAllocation<StrideD> stride_D;
|
||||
cutlass::DeviceAllocation<StrideC_ref> stride_C_ref;
|
||||
cutlass::DeviceAllocation<StrideD_ref> stride_D_ref;
|
||||
cutlass::DeviceAllocation<StrideS_ref> stride_S_ref;
|
||||
cutlass::DeviceAllocation<StrideS> stride_S;
|
||||
|
||||
// Note, this is an array of pointers to alpha and beta scaling values per group
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using Options = GroupedMixedDtypeOptions<QuantType>;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Allocates device-side data
|
||||
void allocate(Options const& options) {
|
||||
int64_t total_elements_A = 0;
|
||||
int64_t total_elements_B = 0;
|
||||
int64_t total_elements_B_dq = 0;
|
||||
int64_t total_elements_C = 0;
|
||||
int64_t total_elements_D = 0;
|
||||
int64_t total_elements_scale = 0;
|
||||
int64_t total_elements_zero = 0;
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
auto M = get<0>(problem);
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.c);
|
||||
|
||||
offset_A.push_back(total_elements_A);
|
||||
offset_B.push_back(total_elements_B * cutlass::sizeof_bits<QuantType>::value / 8);
|
||||
offset_B_dq.push_back(total_elements_B_dq);
|
||||
offset_C.push_back(total_elements_C);
|
||||
offset_D.push_back(total_elements_D);
|
||||
offset_scale.push_back(total_elements_scale);
|
||||
offset_zero.push_back(total_elements_zero);
|
||||
|
||||
int64_t elements_A = M * K;
|
||||
int64_t elements_B = K * N ;
|
||||
int64_t elements_B_dq = K * N;
|
||||
int64_t elements_C = M * N;
|
||||
int64_t elements_D = M * N;
|
||||
int64_t elements_scale = scale_k * N;
|
||||
int64_t elements_zero = scale_k * N;
|
||||
|
||||
total_elements_A += elements_A;
|
||||
total_elements_B += elements_B;
|
||||
total_elements_B_dq += elements_B_dq;
|
||||
total_elements_C += elements_C;
|
||||
total_elements_D += elements_D;
|
||||
total_elements_scale += elements_scale;
|
||||
total_elements_zero += elements_zero;
|
||||
|
||||
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
|
||||
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
|
||||
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {N, M, 1}));
|
||||
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {N, M, 1}));
|
||||
stride_C_host_ref.push_back(cutlass::make_cute_packed_stride(StrideC_ref{}, {M, N, 1}));
|
||||
stride_D_host_ref.push_back(cutlass::make_cute_packed_stride(StrideD_ref{}, {M, N, 1}));
|
||||
stride_S_host_ref.push_back(cutlass::make_cute_packed_stride(StrideS_ref{}, {N, scale_k, 1}));
|
||||
stride_S_host.push_back(cutlass::make_cute_packed_stride(StrideS{}, {N, scale_k, 1}));
|
||||
}
|
||||
|
||||
block_A.reset(total_elements_A);
|
||||
block_B.reset(total_elements_B);
|
||||
block_B_dq.reset(total_elements_B_dq);
|
||||
block_C.reset(total_elements_C);
|
||||
block_D.reset(total_elements_D);
|
||||
block_ref_D.reset(total_elements_D);
|
||||
block_scale.reset(total_elements_scale);
|
||||
block_zero.reset(total_elements_zero);
|
||||
|
||||
block_alpha.reset(options.groups);
|
||||
block_beta.reset(options.groups);
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(Options &options) {
|
||||
|
||||
uint64_t seed = 2020;
|
||||
|
||||
problem_sizes.reset(options.groups);
|
||||
problem_sizes.copy_from_host(options.problem_sizes_host.data());
|
||||
|
||||
//
|
||||
// Assign pointers
|
||||
//
|
||||
|
||||
std::vector<MmaType *> ptr_A_host(options.groups);
|
||||
std::vector<QuantType *> ptr_B_host(options.groups);
|
||||
std::vector<MmaType *> ptr_B_dq_host(options.groups);
|
||||
std::vector<ElementC *> ptr_C_host(options.groups);
|
||||
std::vector<ElementC *> ptr_D_host(options.groups);
|
||||
std::vector<ElementScale *> ptr_scale_host(options.groups);
|
||||
std::vector<ElementZero *> ptr_zero_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
|
||||
ptr_B_dq_host.at(i) = block_B_dq.get() + offset_B_dq.at(i);
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
ptr_scale_host.at(i) = block_scale.get() + offset_scale.at(i);
|
||||
ptr_zero_host.at(i) = block_zero.get() + offset_zero.at(i);
|
||||
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
|
||||
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
|
||||
ptr_alpha_host.at(i) = block_alpha.get() + i;
|
||||
ptr_beta_host.at(i) = block_beta.get() + i;
|
||||
}
|
||||
|
||||
ptr_A.reset(options.groups);
|
||||
ptr_A.copy_from_host(ptr_A_host.data());
|
||||
|
||||
ptr_B.reset(options.groups);
|
||||
ptr_B.copy_from_host(ptr_B_host.data());
|
||||
|
||||
ptr_B_dq.reset(options.groups);
|
||||
ptr_B_dq.copy_from_host(ptr_B_dq_host.data());
|
||||
|
||||
ptr_C.reset(options.groups);
|
||||
ptr_C.copy_from_host(ptr_C_host.data());
|
||||
|
||||
ptr_D.reset(options.groups);
|
||||
ptr_D.copy_from_host(ptr_D_host.data());
|
||||
|
||||
ptr_scale.reset(options.groups);
|
||||
ptr_scale.copy_from_host(ptr_scale_host.data());
|
||||
|
||||
ptr_zero.reset(options.groups);
|
||||
ptr_zero.copy_from_host(ptr_zero_host.data());
|
||||
|
||||
stride_A.reset(options.groups);
|
||||
stride_A.copy_from_host(stride_A_host.data());
|
||||
|
||||
stride_B.reset(options.groups);
|
||||
stride_B.copy_from_host(stride_B_host.data());
|
||||
|
||||
stride_C.reset(options.groups);
|
||||
stride_C.copy_from_host(stride_C_host.data());
|
||||
|
||||
stride_D.reset(options.groups);
|
||||
stride_D.copy_from_host(stride_D_host.data());
|
||||
|
||||
stride_C_ref.reset(options.groups);
|
||||
stride_C_ref.copy_from_host(stride_C_host_ref.data());
|
||||
|
||||
stride_D_ref.reset(options.groups);
|
||||
stride_D_ref.copy_from_host(stride_D_host_ref.data());
|
||||
|
||||
stride_S_ref.reset(options.groups);
|
||||
stride_S_ref.copy_from_host(stride_S_host_ref.data());
|
||||
|
||||
stride_S.reset(options.groups);
|
||||
stride_S.copy_from_host(stride_S_host.data());
|
||||
|
||||
alpha_device.reset(options.groups);
|
||||
alpha_device.copy_from_host(ptr_alpha_host.data());
|
||||
beta_device.reset(options.groups);
|
||||
beta_device.copy_from_host(ptr_beta_host.data());
|
||||
|
||||
initialize_tensor(block_A, seed + 2023);
|
||||
initialize_tensor(block_B, seed + 2022);
|
||||
initialize_tensor(block_C, seed + 2021);
|
||||
initialize_scale(block_scale, options);
|
||||
initialize_zero(block_zero, options);
|
||||
block_alpha.copy_from_host(alpha_host.data());
|
||||
block_beta.copy_from_host(beta_host.data());
|
||||
|
||||
problem_sizes.reset(options.groups);
|
||||
// Reverse MN -> NM for SwapAB
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
auto [M, N, K] = options.problem_sizes_host[i];
|
||||
options.problem_sizes_host[i] = make_tuple(N, M, K);
|
||||
}
|
||||
problem_sizes.copy_from_host(options.problem_sizes_host.data());
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
template <typename Gemm>
|
||||
typename Gemm::Arguments args_from_options(Options const& options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
typename Gemm::Arguments arguments;
|
||||
decltype(arguments.epilogue.thread) fusion_args;
|
||||
|
||||
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
|
||||
// If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.beta = options.beta;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = nullptr;
|
||||
fusion_args.beta_ptr_array = nullptr;
|
||||
// Single alpha and beta for all groups
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
|
||||
}
|
||||
else {
|
||||
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
|
||||
fusion_args.alpha = 0;
|
||||
fusion_args.beta = 0;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = alpha_device.get();
|
||||
fusion_args.beta_ptr_array = beta_device.get();
|
||||
// One alpha and beta per each group
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
|
||||
}
|
||||
|
||||
if constexpr (Gemm::CollectiveMainloop::KernelConversionMode == Gemm::CollectiveMainloop::ConversionMode::DirectConvert) {
|
||||
arguments = typename Gemm::Arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), nullptr},
|
||||
{ptr_B.get(), stride_B.get(), ptr_A.get(), stride_A.get()},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
else if constexpr (Gemm::CollectiveMainloop::KernelConversionMode == Gemm::CollectiveMainloop::ConversionMode::ConvertAndScale) {
|
||||
arguments = typename Gemm::Arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), nullptr},
|
||||
{ptr_B.get(), stride_B.get(), ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.c},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
else {
|
||||
std::cerr << "Invalid mode " << options.mode << ". Must be 0, 1 or 2." << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(Options const& options) {
|
||||
bool passed = true;
|
||||
|
||||
constexpr bool IsFP8Input = cute::is_same_v<MmaType, cutlass::float_e4m3_t> || cute::is_same_v<MmaType, cutlass::float_e5m2_t>;
|
||||
using FP8Sched = cute::conditional_t<size<0>(TileShape{}) == 64, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>;
|
||||
using ScheduleRef = cute::conditional_t<IsFP8Input, FP8Sched, cutlass::gemm::collective::KernelScheduleAuto>;
|
||||
|
||||
|
||||
using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
MmaType, LayoutA, AlignmentA,
|
||||
MmaType, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
ScheduleRef
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloopRef,
|
||||
CollectiveEpilogueRef
|
||||
>;
|
||||
|
||||
using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;
|
||||
using StrideA_verif = typename GemmRef::GemmKernel::StrideA;
|
||||
using StrideB_verif = typename GemmRef::GemmKernel::StrideB;
|
||||
using StrideC_verif = typename GemmRef::GemmKernel::StrideC;
|
||||
using StrideD_verif = typename GemmRef::GemmKernel::StrideD;
|
||||
|
||||
const ElementD epsilon(1e-2f);
|
||||
const ElementD non_zero_floor(1e-4f);
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
// we don't swap and transpose in the verify so revert the problem shape.
|
||||
auto N = get<0>(problem);
|
||||
auto M = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
if (M == 0) {
|
||||
continue;
|
||||
}
|
||||
else {
|
||||
StrideA_verif stride_A_verif;
|
||||
StrideB_verif stride_B_verif;
|
||||
|
||||
stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1));
|
||||
stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1));
|
||||
|
||||
int const scale_k = cutlass::ceil_div(options.k, options.c);
|
||||
auto layout_B = make_layout(cute::make_shape(N, K, Int<1>{}), stride_B_host.at(i));
|
||||
auto layout_scale_zero = make_layout(cute::make_shape(N, scale_k, Int<1>{}), stride_S_host_ref.at(i));
|
||||
cudaStream_t stream = cudaStreamDefault;
|
||||
cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.c, stream);
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
typename GemmRef::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{M, N, K},
|
||||
{block_A.get() + offset_A.at(i), stride_A_verif, block_B_dq.get() + offset_B_dq.at(i), stride_B_verif},
|
||||
{{alpha_host.at(i), beta_host.at(i)}, block_C.get() + offset_C.at(i), stride_C_host_ref.at(i), block_ref_D.get() + offset_D.at(i), stride_D_host_ref.at(i)}
|
||||
};
|
||||
|
||||
// Run the gemm where the scaling is performed outside of the kernel.
|
||||
GemmRef gemm_ref;
|
||||
size_t workspace_size = GemmRef::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
CUTLASS_CHECK(gemm_ref.can_implement(arguments));
|
||||
CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm_ref.run());
|
||||
|
||||
// Wait for kernel to finish
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor);
|
||||
std::cout << "Group " << i << ": " << options.problem_sizes_host[i] << ", alpha: " << alpha_host[i] << ", beta: " << beta_host[i] << " Status: " << passed << std::endl;
|
||||
}
|
||||
}
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options<Gemm>(options, host_problem_shapes_available);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
std::cout << "We passed all checks\n";
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
MixedDtypeResult result;
|
||||
result.passed = verify(options);
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
grouped_mixed_dtype_profiling(gemm, options, result, alpha_host, beta_host);
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.3 Toolkit to run this example
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) {
|
||||
std::cerr << "This example requires CUDA 12.3 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
if (options.mode == MixedDtypeGemmMode::ConvertOnly) {
|
||||
std::cout << "Running in no scale mode." << std::endl;
|
||||
run<GemmConvertOnly>(options, false);
|
||||
}
|
||||
else if (options.mode == MixedDtypeGemmMode::ScaleOnly) {
|
||||
std::cout << "Running in group scale mode." << std::endl;
|
||||
run<GemmScaleOnly>(options, false);
|
||||
}
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
116
examples/69_hopper_mixed_dtype_grouped_gemm/CMakeLists.txt
Normal file
116
examples/69_hopper_mixed_dtype_grouped_gemm/CMakeLists.txt
Normal file
@ -0,0 +1,116 @@
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# Note that we set --iterations=0 for all tests below to disable the performance benchmarking.
|
||||
# Only the correctness check will be run by these commands.
|
||||
|
||||
set(TEST_RANDOM --iterations=0) # Random problem sizes
|
||||
set(TEST_RANDOM_LARGE_GROUP --groups=100 --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes
|
||||
set(TEST_EPILOGUE_LARGE_GROUP --alpha=2.0 --beta=2.0 --groups=100 --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Random problem sizes
|
||||
set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=0.25 --iterations=1) # Random problem sizes
|
||||
|
||||
set(TEST_FIXED --m=2048 --n=5120 --k=8192 --groups=16 --iterations=0) # Fixed problem sizes
|
||||
set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=100 --iterations=0) # Fixed problem sizes
|
||||
|
||||
set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes
|
||||
set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=100 --iterations=0) # Small problem sizes
|
||||
|
||||
set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes
|
||||
set(TEST_RANDOM_PERF_LARGE_GROUP --groups=100 --iterations=10) # Random problem sizes
|
||||
|
||||
set(TEST_DIRECT_BATCHED --m=2048 --n=5120 --k=8192 --mode=0 --iterations=0) # Direct conversion
|
||||
|
||||
set(TEST_SCALE_PERCOL --m=4096 --n=5120 --k=8192 --c=8192 --mode=1 --iterations=0) # Per Column scaling
|
||||
set(TEST_SCALE_GROUP --m=2048 --n=5120 --k=8192 --c=512 --mode=1 --iterations=0) # Group-wise scaling
|
||||
|
||||
cutlass_example_add_executable(
|
||||
69_hopper_mixed_dtype_grouped_gemm
|
||||
69_hopper_mixed_dtype_grouped_gemm.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_RANDOM
|
||||
TEST_RANDOM_LARGE_GROUP
|
||||
TEST_EPILOGUE
|
||||
TEST_EPILOGUE_LARGE_GROUP
|
||||
TEST_EPILOGUE_OP
|
||||
TEST_EPILOGUE_OP_LARGE_GROUP
|
||||
TEST_FIXED
|
||||
TEST_FIXED_LARGE_GROUP
|
||||
TEST_SMALL
|
||||
TEST_SMALL_LARGE_GROUP
|
||||
TEST_RANDOM_PERF
|
||||
TEST_RANDOM_PERF_LARGE_GROUP
|
||||
TEST_DIRECT_BATCHED
|
||||
TEST_SCALE_PERCOL
|
||||
TEST_SCALE_GROUP
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
69_hopper_int4_fp8_grouped_gemm
|
||||
69_hopper_int4_fp8_grouped_gemm.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_RANDOM
|
||||
TEST_RANDOM_LARGE_GROUP
|
||||
TEST_EPILOGUE
|
||||
TEST_EPILOGUE_LARGE_GROUP
|
||||
TEST_EPILOGUE_OP
|
||||
TEST_EPILOGUE_OP_LARGE_GROUP
|
||||
TEST_FIXED
|
||||
TEST_FIXED_LARGE_GROUP
|
||||
TEST_SMALL
|
||||
TEST_SMALL_LARGE_GROUP
|
||||
TEST_RANDOM_PERF
|
||||
TEST_RANDOM_PERF_LARGE_GROUP
|
||||
TEST_DIRECT_BATCHED
|
||||
TEST_SCALE_PERCOL
|
||||
TEST_SCALE_GROUP
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
69_hopper_int4_bf16_grouped_gemm
|
||||
69_hopper_int4_bf16_grouped_gemm.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_RANDOM
|
||||
TEST_RANDOM_LARGE_GROUP
|
||||
TEST_EPILOGUE
|
||||
TEST_EPILOGUE_LARGE_GROUP
|
||||
TEST_EPILOGUE_OP
|
||||
TEST_EPILOGUE_OP_LARGE_GROUP
|
||||
TEST_FIXED
|
||||
TEST_FIXED_LARGE_GROUP
|
||||
TEST_SMALL
|
||||
TEST_SMALL_LARGE_GROUP
|
||||
TEST_RANDOM_PERF
|
||||
TEST_RANDOM_PERF_LARGE_GROUP
|
||||
TEST_DIRECT_BATCHED
|
||||
TEST_SCALE_PERCOL
|
||||
TEST_SCALE_GROUP
|
||||
)
|
||||
46
examples/69_hopper_mixed_dtype_grouped_gemm/README.md
Normal file
46
examples/69_hopper_mixed_dtype_grouped_gemm/README.md
Normal file
@ -0,0 +1,46 @@
|
||||
This example extends Example 55 to support Grouped GEMMs in CUTLASS.
|
||||
|
||||
## High level overview
|
||||
|
||||
This example shows how to perform Grouped GEMMs on Hopper when A and B have different types. In the Grouped GEMM, multiple GEMMs with potentially different problem shapes can be excetued in a batch. The interface is similar to the standard mixed-input GEMM presented in Example 55, with a few noteworthy differences:
|
||||
- inside the collective builder, replace the layout types with layout pointer types.
|
||||
- in the arguments, pass the group size, array of the problem sizes, and the array of strides for matrix A and B.
|
||||
- if scales and zero-points are included, also pass the array of their strides in the arguments.
|
||||
|
||||
Note that in Example 55, the argument `--g` is used to determine the group size of scaling. To avoid confusion with the `--groups` argument in this example, which defines the number of GEMMs, `--c` is used here to represent the group size for scaling.
|
||||
|
||||
## Upcoming features
|
||||
|
||||
Currently, the Mixed-input Grouped GEMM only supports row-wise scaling, and group-wise scaling for identical problem shapes across all groups. Please contact us if zero-points or block-wise scaling are needed.
|
||||
|
||||
## Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
@ -0,0 +1,191 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "../55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp"
|
||||
|
||||
template<class QuantType>
|
||||
class GroupedMixedDtypeOptions : public MixedDtypeOptions {
|
||||
public:
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<cute::Shape<int,int,int>>;
|
||||
using UnderlyingProblemShape = typename ProblemShape::UnderlyingProblemShape;
|
||||
|
||||
int groups = 6;
|
||||
int c = 512;
|
||||
std::string benchmark_path;
|
||||
std::vector<UnderlyingProblemShape> problem_sizes_host;
|
||||
|
||||
GroupedMixedDtypeOptions() : MixedDtypeOptions()
|
||||
{
|
||||
m = 1024;
|
||||
n = 2048;
|
||||
k = 512;
|
||||
};
|
||||
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
cmd.get_cmd_line_argument("groups", groups);
|
||||
cmd.get_cmd_line_argument("benchmark", benchmark_path);
|
||||
cmd.get_cmd_line_argument("c", c);
|
||||
MixedDtypeOptions::parse(argc, args);
|
||||
|
||||
problem_sizes_host = benchmark_path.empty() ? randomize_problems(cmd) : load_benchmark_problems();
|
||||
}
|
||||
|
||||
std::ostream& print_usage(std::ostream& out) const {
|
||||
out << "69_hopper_mixed_dtype_grouped_gemm\n\n"
|
||||
<< "Options:\n"
|
||||
<< " --help Display this usage statement\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM for all groups\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM for all groups\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM for all groups\n"
|
||||
<< " --c=<int> Sets the chunk size for scaling the quantized weights\n"
|
||||
<< " --groups=<int> Sets the number of individual GEMM problems\n"
|
||||
<< " --mode=<int> The mode to run the gemm\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --iterations=<int> Number of profiling iterations\n"
|
||||
<< " --warmup=<int> Number of warmup iterations\n"
|
||||
<< " --benchmark=<str> Executes a benchmark problem size\n";
|
||||
return out;
|
||||
}
|
||||
|
||||
double gflops(double runtime_s) const {
|
||||
uint64_t fmas = std::accumulate(problem_sizes_host.begin(), problem_sizes_host.end(), 0ULL,
|
||||
[](uint64_t sum, const UnderlyingProblemShape& problem) {
|
||||
return sum + static_cast<uint64_t>(cute::get<0>(problem)) *
|
||||
static_cast<uint64_t>(cute::get<1>(problem)) *
|
||||
static_cast<uint64_t>(cute::get<2>(problem));
|
||||
});
|
||||
return (2.0 * fmas) / (runtime_s * 1e9);
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr int tma_alignment_bits = 128;
|
||||
const int alignment = tma_alignment_bits / cutlass::sizeof_bits<QuantType>::value;
|
||||
|
||||
std::vector<UnderlyingProblemShape> randomize_problems(cutlass::CommandLine& cmd) {
|
||||
std::vector<UnderlyingProblemShape> problems;
|
||||
problems.reserve(groups);
|
||||
|
||||
int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1;
|
||||
cmd.get_cmd_line_argument("m", cmd_line_m);
|
||||
cmd.get_cmd_line_argument("n", cmd_line_n);
|
||||
cmd.get_cmd_line_argument("k", cmd_line_k);
|
||||
|
||||
for (int i = 0; i < groups; ++i) {
|
||||
int m = (cmd_line_m >= 0) ? cmd_line_m : alignment * ((rand() % 64) + 1);
|
||||
int n = (cmd_line_n >= 0) ? cmd_line_n : this->n;
|
||||
int k = (cmd_line_k >= 0) ? cmd_line_k : this->k;
|
||||
|
||||
if (k % alignment != 0) {
|
||||
throw std::runtime_error("Error: k dimension must be a multiple of " + std::to_string(alignment));
|
||||
}
|
||||
problems.push_back({m, n, k});
|
||||
}
|
||||
return problems;
|
||||
}
|
||||
|
||||
std::vector<UnderlyingProblemShape> load_benchmark_problems() {
|
||||
std::ifstream file(benchmark_path);
|
||||
if (!file) {
|
||||
throw std::runtime_error("Failed to open benchmark file: " + benchmark_path);
|
||||
}
|
||||
|
||||
std::vector<UnderlyingProblemShape> problems;
|
||||
int idx;
|
||||
std::string extent_str;
|
||||
|
||||
while (file >> idx >> extent_str) {
|
||||
if (idx < 0 || extent_str.empty()) break;
|
||||
|
||||
std::vector<std::string> tokens;
|
||||
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
|
||||
|
||||
cutlass::gemm::GemmCoord extent;
|
||||
for (int i = 0; i < std::min(3, static_cast<int>(tokens.size())); ++i) {
|
||||
int x = std::stoi(tokens[i]);
|
||||
extent.at(i) = (x % alignment) ? x + (alignment - (x % alignment)) : x;
|
||||
}
|
||||
|
||||
if (extent.product()) {
|
||||
problems.push_back({extent.m(), extent.n(), extent.k()});
|
||||
}
|
||||
}
|
||||
groups = static_cast<int>(problems.size());
|
||||
return problems;
|
||||
}
|
||||
};
|
||||
|
||||
template <class QuantType, class Gemm, class ElementAccumulator>
|
||||
void grouped_mixed_dtype_profiling(
|
||||
Gemm& gemm,
|
||||
const GroupedMixedDtypeOptions<QuantType>& options,
|
||||
MixedDtypeResult& result,
|
||||
const std::vector<ElementAccumulator>& alpha_host,
|
||||
const std::vector<ElementAccumulator>& beta_host) {
|
||||
|
||||
if (options.iterations <= 0) return;
|
||||
|
||||
cudaEvent_t start, stop;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventCreate(&stop);
|
||||
|
||||
std::vector<float> runtimes;
|
||||
runtimes.reserve(options.iterations);
|
||||
|
||||
for (int iter = 0; iter < options.warmup + options.iterations; ++iter) {
|
||||
cudaEventRecord(start);
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
cudaEventRecord(stop);
|
||||
cudaEventSynchronize(stop);
|
||||
|
||||
if (iter >= options.warmup) {
|
||||
float milliseconds = 0;
|
||||
cudaEventElapsedTime(&milliseconds, start, stop);
|
||||
runtimes.push_back(milliseconds);
|
||||
}
|
||||
}
|
||||
|
||||
cudaEventDestroy(start);
|
||||
cudaEventDestroy(stop);
|
||||
|
||||
result.avg_runtime_ms = std::accumulate(runtimes.begin(), runtimes.end(), 0.0f) / runtimes.size();
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
std::cout << " Groups : " << options.groups << '\n'
|
||||
<< " Avg runtime : " << result.avg_runtime_ms << " ms\n"
|
||||
<< " GFLOPS : " << result.gflops << '\n';
|
||||
}
|
||||
485
examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu
Normal file
485
examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu
Normal file
@ -0,0 +1,485 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A FP16 dense GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS.
|
||||
|
||||
This example demonstrates minimal set of changes needed to transition from a Hopper CUTLASS 3.x
|
||||
GEMM kernel (see example 48_hopper_warp_specialized_gemm) to a Blackwell 3.x CUTLASS GEMM kernel.
|
||||
|
||||
The Blackwell SM100 CUTLASS kernel uses of the following Blackwell SM100 features:
|
||||
|
||||
1. New series of Tensor Core MMA Instructions (tcgen05) introduced on the Blackwell architecture (sm100a)
|
||||
which have 2x throughput compared to Hopper Tensor Core MMA instructions (WGMMA).
|
||||
|
||||
Note that Hopper WGMMA Tensor Core MMA instructions are not compatible on Blackwell (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
2. A new per-SM memory called Tensor Memory (TMEM) introduced on the Blackwell architecture (sm100a).
|
||||
Blackwell SM100 Tensor Core MMA instructions store their accumulation results in TMEM instead of the
|
||||
Register File. (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/).
|
||||
|
||||
3. An extended flavor of the warp-specialized kernel design introduced in Hopper enabled by use of TMEM
|
||||
which allows us to decouple the execution of MMA and epilogue into separate warps.
|
||||
|
||||
4. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
Usage:
|
||||
$ ./examples/70_blackwell_gemm/70_blackwell_fp16_gemm --m=8192 --n=8192 --k=8192
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = half_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = half_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = float; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
|
||||
// MMA and Cluster Tile Shapes
|
||||
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0
|
||||
using MmaTileShape_MNK = Shape<_256,_128,_64>;
|
||||
// Shape of the threadblocks in a cluster
|
||||
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||
|
||||
// Build the epilogue
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
// Build the mainloop
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
// Compose into a kernel
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using DeviceGemmReference = cutlass::reference::device::Gemm<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
int swizzle;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(8192), n(8192), k(8192),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10),
|
||||
swizzle(0)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "70_blackwell_fp16_gemm\n\n"
|
||||
<< " Blackwell FP16 GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --swizzle=<int> Cluster rasterization swizzle\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "70_blackwell_fp16_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023) {
|
||||
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(0);
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(-2);
|
||||
} else {
|
||||
scope_max = Element(8);
|
||||
scope_min = Element(-8);
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
||||
|
||||
block_A.reset(options.m * options.k);
|
||||
block_B.reset(options.k * options.n);
|
||||
block_C.reset(options.m * options.n);
|
||||
block_D.reset(options.m * options.n);
|
||||
block_ref_D.reset(options.m * options.n);
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, 1},
|
||||
{block_A.get(), stride_A, block_B.get(), stride_B},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k}));
|
||||
cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.k, options.n}));
|
||||
cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n}));
|
||||
cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n}));
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
DeviceGemmReference gemm_reference;
|
||||
|
||||
// Launch device reference gemm kernel
|
||||
gemm_reference(
|
||||
{options.m, options.n, options.k},
|
||||
ElementAccumulator(options.alpha),
|
||||
ref_A,
|
||||
ref_B,
|
||||
ElementAccumulator(options.beta),
|
||||
ref_C,
|
||||
ref_D);
|
||||
|
||||
// Wait for kernel to finish
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least 100a.
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major != 10 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
672
examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu
Normal file
672
examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu
Normal file
@ -0,0 +1,672 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A FP8 dense GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS.
|
||||
|
||||
This example demonstrates minimal set of changes needed to transition from a Hopper CUTLASS 3.x
|
||||
FP8 GEMM kernel (see example 54_hopper_fp8_warp_specialized_gemm) to a Blackwell SM100 FP8 GEMM kernel.
|
||||
|
||||
This example shows all important fusions used by FP8 gemm kernels,
|
||||
i.e., scale factor for A, B, C, D tensor, the abs_max value of D tensor.
|
||||
|
||||
The Blackwell SM100 CUTLASS kernel uses of the following Blackwell SM100 features:
|
||||
|
||||
1. New series of Tensor Core MMA Instructions (tcgen05) introduced on the Blackwell architecture (sm100a)
|
||||
which have 2x throughput compared to Hopper Tensor Core MMA instructions (WGMMA).
|
||||
|
||||
Note that Hopper WGMMA Tensor Core MMA instructions are not compatible on Blackwell (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
2. A new per-SM memory called Tensor Memory (TMEM) introduced on the Blackwell architecture (sm100a).
|
||||
Blackwell SM100 Tensor Core MMA instructions store their accumulation results in TMEM instead of the
|
||||
Register File. (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/).
|
||||
|
||||
3. An extended flavor of the warp-specialized kernel design introduced in Hopper enabled by use of TMEM
|
||||
which allows us to decouple the execution of MMA and epilogue into separate warps.
|
||||
|
||||
4. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
Usage:
|
||||
$ ./examples/70_blackwell_gemm/70_blackwell_fp8_gemm --m=8192 --n=8192 --k=8192
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
// MMA type
|
||||
using ElementAccumulator = float;
|
||||
|
||||
// Epilogue types
|
||||
using ElementBias = cutlass::half_t;
|
||||
using ElementCompute = float;
|
||||
using ElementAux = ElementC;
|
||||
using LayoutAux = LayoutC;
|
||||
using ElementAmax = float;
|
||||
|
||||
// MMA and Cluster Tile Shapes
|
||||
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0
|
||||
using MmaTileShape_MNK = Shape<_256,_128,_64>;
|
||||
// Shape of the threadblocks in a cluster
|
||||
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||
|
||||
using FusionOp = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux<
|
||||
LayoutC, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux, ElementAmax, ElementBias>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutC, AlignmentD,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto,
|
||||
FusionOp
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Extract information from Gemm kernel.
|
||||
using EpilogueOutputOp = typename Gemm::EpilogueOutputOp;
|
||||
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
|
||||
using ElementAmax = typename EpilogueOutputOp::ElementAmax;
|
||||
using ActivationFunctor = typename EpilogueOutputOp::ActivationFn;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using StrideAux = StrideC;
|
||||
|
||||
constexpr bool IsDFp8 =
|
||||
cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
|
||||
cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
|
||||
|
||||
constexpr bool IsAuxFp8 =
|
||||
cute::is_same_v<ElementAux, cutlass::float_e4m3_t> or
|
||||
cute::is_same_v<ElementAux, cutlass::float_e5m2_t>;
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
StrideAux stride_aux;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::HostTensor<ElementA , LayoutA > tensor_A;
|
||||
cutlass::HostTensor<ElementB , LayoutB > tensor_B;
|
||||
cutlass::HostTensor<ElementC , LayoutC > tensor_C;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_D;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_ref_D;
|
||||
cutlass::HostTensor<ElementAux, LayoutAux> tensor_aux;
|
||||
cutlass::HostTensor<ElementAux, LayoutAux> tensor_ref_aux;
|
||||
|
||||
using LayoutScalar = cutlass::layout::PackedVectorLayout;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_alpha;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_beta;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_A;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_B;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_C;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_D;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_aux;
|
||||
cutlass::HostTensor<ElementAmax , LayoutScalar> abs_max_D;
|
||||
cutlass::HostTensor<ElementAmax , LayoutScalar> reference_abs_max_D;
|
||||
cutlass::HostTensor<ElementAmax , LayoutScalar> abs_max_aux;
|
||||
cutlass::HostTensor<ElementAmax , LayoutScalar> reference_abs_max_aux;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
float scale_a = 1.f, scale_b = 1.f, scale_c = 1.f, scale_d = 1.f, scale_aux = 1.f;
|
||||
bool device_scale = false;
|
||||
bool save_aux = true;
|
||||
bool save_amax = true;
|
||||
int iterations = 1000;
|
||||
int m = 1024, n = 512, k = 1024, l = 1;
|
||||
int swizzle = 0;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("l", l);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("scale_a", scale_a, 1.f);
|
||||
cmd.get_cmd_line_argument("scale_b", scale_b, 1.f);
|
||||
cmd.get_cmd_line_argument("scale_c", scale_c, 1.f);
|
||||
cmd.get_cmd_line_argument("scale_d", scale_d, 1.f);
|
||||
cmd.get_cmd_line_argument("scale_aux", scale_aux, 1.f);
|
||||
cmd.get_cmd_line_argument("device_scale", device_scale, false);
|
||||
cmd.get_cmd_line_argument("save_aux", save_aux, true);
|
||||
cmd.get_cmd_line_argument("save_amax", save_amax, true);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "70_blackwell_fp8_gemm\n\n"
|
||||
<< " Blackwell FP8 GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --swizzle=<int> Cluster rasterization swizzle\n"
|
||||
<< " --scale_a=<f32> Scaling factor for A\n"
|
||||
<< " --scale_b=<f32> Scaling factor for B\n"
|
||||
<< " --scale_c=<f32> Scaling factor for C\n"
|
||||
<< " --scale_d=<f32> Scaling factor for D (ignored for non-fp8 D)\n"
|
||||
<< " --scale_aux=<f32> Scaling factor for the auxiliary tensor (ignored for non-fp8 aux)\n"
|
||||
<< " --device_scale=<bool> Copy scalars to device memory before kernel launch (default: false)\n"
|
||||
<< " --save_aux=<bool> Save the pre-activation as an auxiliary tensor (default: true)\n"
|
||||
<< " --save_amax=<bool> Save the pre-scaled max absolute value of any fp8 outputs (aux and/or D) (default: true)\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "70_blackwell_fp8_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
}
|
||||
else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
|
||||
stride_aux = stride_D;
|
||||
|
||||
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
|
||||
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
|
||||
|
||||
tensor_A.resize(a_coord);
|
||||
tensor_B.resize(b_coord);
|
||||
tensor_C.resize(c_coord);
|
||||
tensor_D.resize(c_coord);
|
||||
tensor_ref_D.resize(c_coord);
|
||||
|
||||
initialize_tensor(tensor_A.host_view(), seed + 2022);
|
||||
initialize_tensor(tensor_B.host_view(), seed + 2023);
|
||||
initialize_tensor(tensor_C.host_view(), seed + 2024);
|
||||
|
||||
tensor_A.sync_device();
|
||||
tensor_B.sync_device();
|
||||
tensor_C.sync_device();
|
||||
tensor_D.sync_device();
|
||||
|
||||
if (options.save_aux) {
|
||||
tensor_aux.resize(c_coord);
|
||||
tensor_aux.sync_device();
|
||||
tensor_ref_aux.resize(c_coord);
|
||||
}
|
||||
|
||||
if (options.device_scale) {
|
||||
scalar_alpha.resize(cutlass::make_Coord(1));
|
||||
scalar_beta.resize(cutlass::make_Coord(1));
|
||||
scale_A.resize(cutlass::make_Coord(1));
|
||||
scale_B.resize(cutlass::make_Coord(1));
|
||||
scale_C.resize(cutlass::make_Coord(1));
|
||||
scale_D.resize(cutlass::make_Coord(1));
|
||||
scale_aux.resize(cutlass::make_Coord(1));
|
||||
|
||||
cutlass::reference::host::TensorFill(scalar_alpha.host_view(), options.alpha);
|
||||
cutlass::reference::host::TensorFill(scalar_beta.host_view(), options.beta);
|
||||
cutlass::reference::host::TensorFill(scale_A.host_view(), options.scale_a);
|
||||
cutlass::reference::host::TensorFill(scale_B.host_view(), options.scale_b);
|
||||
cutlass::reference::host::TensorFill(scale_C.host_view(), options.scale_c);
|
||||
cutlass::reference::host::TensorFill(scale_D.host_view(), options.scale_d);
|
||||
cutlass::reference::host::TensorFill(scale_aux.host_view(), options.scale_aux);
|
||||
|
||||
scalar_alpha.sync_device();
|
||||
scalar_beta.sync_device();
|
||||
scale_A.sync_device();
|
||||
scale_B.sync_device();
|
||||
scale_C.sync_device();
|
||||
scale_D.sync_device();
|
||||
scale_aux.sync_device();
|
||||
}
|
||||
|
||||
if (IsDFp8 && options.save_amax) {
|
||||
abs_max_D.resize(cutlass::make_Coord(1));
|
||||
abs_max_D.sync_device();
|
||||
reference_abs_max_D.resize(cutlass::make_Coord(1));
|
||||
}
|
||||
|
||||
if (IsAuxFp8 && options.save_aux && options.save_amax) {
|
||||
abs_max_aux.resize(cutlass::make_Coord(1));
|
||||
abs_max_aux.sync_device();
|
||||
reference_abs_max_aux.resize(cutlass::make_Coord(1));
|
||||
}
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, options.l},
|
||||
{tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B},
|
||||
{
|
||||
{}, // epilogue.thread
|
||||
tensor_C.device_data(), stride_C,
|
||||
tensor_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
|
||||
auto &fusion_args = arguments.epilogue.thread;
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.beta = options.beta;
|
||||
fusion_args.alpha_ptr = scalar_alpha.device_data();
|
||||
fusion_args.beta_ptr = scalar_beta.device_data();
|
||||
fusion_args.scale_a = options.scale_a;
|
||||
fusion_args.scale_b = options.scale_b;
|
||||
fusion_args.scale_c = options.scale_c;
|
||||
fusion_args.scale_a_ptr = scale_A.device_data();
|
||||
fusion_args.scale_b_ptr = scale_B.device_data();
|
||||
fusion_args.scale_c_ptr = scale_C.device_data();
|
||||
|
||||
// ignored if tensor types are not fp8
|
||||
fusion_args.scale_d = options.scale_d;
|
||||
fusion_args.scale_aux = options.scale_aux;
|
||||
fusion_args.scale_d_ptr = scale_D.device_data();
|
||||
fusion_args.scale_aux_ptr = scale_aux.device_data();
|
||||
|
||||
// leaving/setting these as nullptr disables the fusion at runtime
|
||||
fusion_args.bias_ptr = nullptr;
|
||||
|
||||
if (options.save_aux) {
|
||||
fusion_args.aux_ptr = tensor_aux.device_data();
|
||||
fusion_args.dAux = stride_aux;
|
||||
if (options.save_amax) {
|
||||
fusion_args.amax_aux_ptr = abs_max_aux.device_data();
|
||||
}
|
||||
}
|
||||
|
||||
if (options.save_amax) {
|
||||
fusion_args.amax_D_ptr = abs_max_D.device_data();
|
||||
}
|
||||
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
auto A = cute::make_tensor(tensor_A.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A));
|
||||
auto B = cute::make_tensor(tensor_B.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B));
|
||||
auto C = cute::make_tensor(tensor_C.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C));
|
||||
auto D = cute::make_tensor(tensor_ref_D.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D));
|
||||
auto Aux = cute::make_tensor(tensor_ref_aux.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_aux));
|
||||
using unused_t = decltype(D);
|
||||
|
||||
cutlass::reference::host::GettMainloopParams<ElementAccumulator, decltype(A), decltype(B)> mainloop_params{A, B};
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
ElementScalar,
|
||||
ElementScalar,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
decltype(C),
|
||||
decltype(D),
|
||||
unused_t, // bias
|
||||
decltype(Aux),
|
||||
unused_t, // valpha
|
||||
unused_t, // vbeta
|
||||
ActivationFunctor
|
||||
> epilogue_params;
|
||||
|
||||
epilogue_params.C = C;
|
||||
epilogue_params.D = D;
|
||||
epilogue_params.Aux = Aux;
|
||||
epilogue_params.alpha = options.alpha;
|
||||
epilogue_params.beta = options.beta;
|
||||
epilogue_params.scale_a = options.scale_a;
|
||||
epilogue_params.scale_b = options.scale_b;
|
||||
epilogue_params.scale_c = options.scale_c;
|
||||
epilogue_params.scale_d = options.scale_d;
|
||||
epilogue_params.scale_aux = options.scale_aux;
|
||||
epilogue_params.abs_max_D = reference_abs_max_D.host_data();
|
||||
epilogue_params.abs_max_Aux = reference_abs_max_aux.host_data();
|
||||
|
||||
// get reference result
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// compare_reference
|
||||
tensor_D.sync_host();
|
||||
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
|
||||
|
||||
if (IsDFp8 && options.save_amax) {
|
||||
abs_max_D.sync_host();
|
||||
passed &= abs_max_D.at(cutlass::make_Coord(0)) == reference_abs_max_D.at(cutlass::make_Coord(0));
|
||||
}
|
||||
|
||||
if (options.save_aux) {
|
||||
tensor_aux.sync_host();
|
||||
passed &= cutlass::reference::host::TensorEquals(tensor_ref_aux.host_view(), tensor_aux.host_view());
|
||||
if (IsAuxFp8 && options.save_amax) {
|
||||
abs_max_aux.sync_host();
|
||||
passed &= abs_max_aux.at(cutlass::make_Coord(0)) == reference_abs_max_aux.at(cutlass::make_Coord(0));
|
||||
}
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least sm100a.
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12) {
|
||||
std::cerr << "This example requires CUDA 12 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major != 10 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Run
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
56
examples/70_blackwell_gemm/CMakeLists.txt
Normal file
56
examples/70_blackwell_gemm/CMakeLists.txt
Normal file
@ -0,0 +1,56 @@
|
||||
|
||||
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
set(TEST_SWIZZLE_1 --swizzle=1)
|
||||
set(TEST_SWIZZLE_2 --swizzle=2)
|
||||
set(TEST_SWIZZLE_5 --swizzle=5)
|
||||
set(TEST_SWIZZLE_5_UNEVEN --swizzle=5 --m=4096 --n=16384)
|
||||
|
||||
if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
|
||||
cutlass_example_add_executable(
|
||||
70_blackwell_fp16_gemm
|
||||
70_blackwell_fp16_gemm.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_SWIZZLE_1
|
||||
TEST_SWIZZLE_2
|
||||
TEST_SWIZZLE_5
|
||||
TEST_SWIZZLE_5_UNEVEN
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
70_blackwell_fp8_gemm
|
||||
70_blackwell_fp8_gemm.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_SWIZZLE_1
|
||||
TEST_SWIZZLE_2
|
||||
TEST_SWIZZLE_5
|
||||
TEST_SWIZZLE_5_UNEVEN
|
||||
)
|
||||
endif()
|
||||
@ -0,0 +1,572 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Blackwell SM100 GEMM example demonstrating compatible mainloop+epilogue builder schedules
|
||||
and epilogue visitor tree (EVT) construction
|
||||
|
||||
Example usage:
|
||||
$ ./examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder \
|
||||
--m=2048 --n=2048 --k=2048 --l=2
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm_complex.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
bool error;
|
||||
|
||||
int m, n, k, l;
|
||||
float alpha, beta;
|
||||
int swizzle;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
error(false),
|
||||
m(2048), n(2048), k(2048), l(1),
|
||||
alpha(1.f), beta(0.f),
|
||||
swizzle(0)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m, 2048);
|
||||
cmd.get_cmd_line_argument("n", n, 2048);
|
||||
cmd.get_cmd_line_argument("k", k, 2048);
|
||||
cmd.get_cmd_line_argument("l", l, 1);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "71_blackwell_gemm_with_collective_builder\n\n"
|
||||
<< " This example showcases the use of CUTLASS's collective operation builders to easily construct\n"
|
||||
<< " performant kernels targeting NVIDIA's Blackwell architecture.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the L extent (batch count) of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --swizzle=<int> Cluster rasterization swizzle\n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023) {
|
||||
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
// Wrapper to construct, run, and verify a GEMM. This example showcases CUTLASS's collective
|
||||
// operation builders by specializing the GEMM on the kernel+epilogue schedule it will use and the
|
||||
// number of pipeline stages.
|
||||
template <
|
||||
// Type of kernel schedule to generate
|
||||
class MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto,
|
||||
// Type of epilogue schedule to generate
|
||||
class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto,
|
||||
// Number of pipeline stages to use
|
||||
class StageCountType = cutlass::gemm::collective::StageCountAuto,
|
||||
// Do we use custom epilogue visitor tree (EVT) fusion
|
||||
bool UseCustomEVT = false
|
||||
>
|
||||
struct ExampleRunner {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
using LayoutD = cutlass::layout::ColumnMajor;
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
using ElementScalar = float;
|
||||
|
||||
using ClusterShapeMNK = Shape<_2,_2,_1>;
|
||||
static constexpr bool Use2SmMma =
|
||||
// Manually specified 2sm cluster MMA schedule, will error if cluster M is not a multiple of 2
|
||||
std::is_same_v<MainloopScheduleType, cutlass::gemm::KernelTmaWarpSpecialized2SmSm100> ||
|
||||
// Auto schedule will try to select 2sm cluster MMA based on cluster M
|
||||
std::is_same_v<MainloopScheduleType, cutlass::gemm::collective::KernelScheduleAuto> && size<0>(ClusterShapeMNK{}) % 2 == 0;
|
||||
// The MMA tile used by the mainloop collective. Blackwell 1sm MMA supports up to MMA tile M = 128, 2sm MMA supports up to MMA tile M = 256
|
||||
using MmaTileMNK = std::conditional_t<Use2SmMma, Shape<_256,_128,_64>, Shape<_128,_128,_64>>;
|
||||
|
||||
// 16B alignment lets us use TMA
|
||||
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
|
||||
// Blackwell fusions for the most part use the same EVT nodes used in Hopper. Most Blackwell EVTs will alias to their Hopper counterparts.
|
||||
// EVT nodes new to Blackwell mainly relate to narrow precision scale factor generation and are contained in include/cutlass/epilogue/fusion/sm100_visitor_*.hpp
|
||||
// See include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp for EVT construction using these new nodes
|
||||
// Fusions relating to narrow-precision scale factor generation are demonstrated in example 72b and can only be used in blackwell kernels
|
||||
using CustomEVT = // alpha * acc + beta * C
|
||||
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::homogeneous_multiply_add, ElementD, ElementCompute, RoundStyle>, // beta * C + (alpha * acc)
|
||||
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementScalar>, // beta
|
||||
cutlass::epilogue::fusion::Sm90SrcFetch<ElementC>, // C
|
||||
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
|
||||
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementScalar>, // alpha
|
||||
cutlass::epilogue::fusion::Sm90AccFetch // acc
|
||||
>
|
||||
>;
|
||||
|
||||
// As in Hopper, a predefined set of fusion operations are provided in include/cutlass/epilogue/fusion/operations.hpp and can be passed to the epilogue builder
|
||||
// Fusions operations supported by the Hopper TMA epilogue will also be supported by the Blackwell TMA epilogue
|
||||
// Fusions relating to narrow-precision scale factor generation are demonstrated in example 72b and can only be used in blackwell kernels
|
||||
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
|
||||
MmaTileMNK, ClusterShapeMNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD,
|
||||
EpilogueScheduleType,
|
||||
cute::conditional_t<UseCustomEVT, CustomEVT, DefaultOperation>
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileMNK, ClusterShapeMNK,
|
||||
cute::conditional_t<cute::is_same_v<StageCountType, cutlass::gemm::collective::StageCountAuto>,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
StageCountType>,
|
||||
MainloopScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
using LayoutTagA = cutlass::gemm::detail::StrideToLayoutTagA_t<StrideA>;
|
||||
using LayoutTagB = cutlass::gemm::detail::StrideToLayoutTagB_t<StrideB>;
|
||||
using LayoutTagC = cutlass::gemm::detail::StrideToLayoutTagC_t<StrideC>;
|
||||
using LayoutTagD = cutlass::gemm::detail::StrideToLayoutTagC_t<StrideD>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
uint64_t seed = 0;
|
||||
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementD> block_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementD> block_ref_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
bool verify(const ProblemShapeType& problem_size, float alpha, float beta) {
|
||||
auto [M, N, K, L] = problem_size;
|
||||
|
||||
cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({M, K}));
|
||||
cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({K, N}));
|
||||
cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({M, N}));
|
||||
cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({M, N}));
|
||||
|
||||
cutlass::reference::device::GemmComplex(
|
||||
{M, N, K},
|
||||
ElementScalar(alpha),
|
||||
ref_A,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
ref_B,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
ElementScalar(beta),
|
||||
ref_C,
|
||||
ref_D,
|
||||
ElementAccumulator(0),
|
||||
L, // batch_count
|
||||
M * K, // batch_stride_A
|
||||
K * N, // batch_stride_B
|
||||
M * N, // batch_stride_C
|
||||
M * N // batch_stride_D
|
||||
);
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Reference kernel failed. Last CUDA error: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const ProblemShapeType& problem_size) {
|
||||
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
|
||||
|
||||
block_A.reset(M * K * L);
|
||||
block_B.reset(K * N * L);
|
||||
block_C.reset(M * N * L);
|
||||
block_D.reset(M * N * L);
|
||||
block_ref_D.reset(M * N * L);
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
}
|
||||
|
||||
bool run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
|
||||
ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l};
|
||||
|
||||
initialize(problem_size);
|
||||
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
problem_size,
|
||||
{block_A.get(), stride_A, block_B.get(), stride_B},
|
||||
{{}, // epilogue.thread
|
||||
block_C.get(), stride_C, block_D.get(), stride_D},
|
||||
hw_info
|
||||
};
|
||||
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
|
||||
// See example 48 for details on custom EVT construction
|
||||
if constexpr (UseCustomEVT) {
|
||||
arguments.epilogue.thread =
|
||||
{ // ternary op : beta * C + (alpha * acc)
|
||||
{{options.beta}}, // leaf op+args : beta
|
||||
{}, // leaf op+args : C
|
||||
{ // binary op : alpha * acc
|
||||
{{options.alpha}}, // leaf op+args : alpha
|
||||
{}, // leaf op+args : acc
|
||||
{} // binary args : multiplies
|
||||
}, // end binary op
|
||||
{} // ternary args : multiply_add
|
||||
}; // end ternary op
|
||||
}
|
||||
// Pre-defined fusions will have flat, named args for user-friendlyness
|
||||
else {
|
||||
arguments.epilogue.thread.alpha = options.alpha;
|
||||
arguments.epilogue.thread.beta = options.beta;
|
||||
}
|
||||
|
||||
Gemm gemm_op;
|
||||
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "This kernel is not supported. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
status = gemm_op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Run the GEMM
|
||||
status = gemm_op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Verify that the result is correct
|
||||
bool passed = verify(problem_size, options.alpha, options.beta);
|
||||
if (!passed) {
|
||||
std::cerr << "Reference check failed" << std::endl;
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to print a description of the example run and its result
|
||||
void print_result(const std::string& description, bool passed) {
|
||||
std::cout << description << ": " << (passed ? "Passed" : "Failed") << std::endl;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.error) {
|
||||
std::cerr << "Aborting execution." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
//
|
||||
// Run examples
|
||||
//
|
||||
|
||||
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
|
||||
// information is used by the underlying kernel.
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
bool passed;
|
||||
|
||||
// Auto mainloop and epilogue schedules must be used together to guarantee functionality
|
||||
ExampleRunner<> runner_0;
|
||||
passed = runner_0.run(options, hw_info);
|
||||
print_result("KernelScheduleAuto mainloop schedule with EpilogueScheduleAuto epilogue schedule", passed);
|
||||
|
||||
// Mainloop stage counts can be specified manually
|
||||
// It is the user's responsibility to ensure there is enough device smem to allocate manual stage counts
|
||||
ExampleRunner<
|
||||
cutlass::gemm::collective::KernelScheduleAuto,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto,
|
||||
_3> runner_1;
|
||||
passed = runner_1.run(options, hw_info);
|
||||
print_result("KernelScheduleAuto mainloop schedule with EpilogueScheduleAuto epilogue schedule and 3 mainloop stages", passed);
|
||||
|
||||
// 1SM cluster MMA mainloop schedules can be used with direct store ("no-smem") epilogue schedules
|
||||
ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized1SmSm100, cutlass::epilogue::NoSmemWarpSpecialized1Sm> runner_2;
|
||||
passed = runner_2.run(options, hw_info);
|
||||
print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with NoSmemWarpSpecialized1Sm epilogue schedule", passed);
|
||||
|
||||
// 1SM cluster MMA mainloop schedules can also be used with 1SM TMA epilogue schedules
|
||||
// 1SM cluster MMA mainloop schedules will not work with 2SM TMA epilogue schedules
|
||||
ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized1SmSm100, cutlass::epilogue::TmaWarpSpecialized1Sm> runner_3;
|
||||
passed = runner_3.run(options, hw_info);
|
||||
print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with TmaWarpSpecialized1Sm epilogue schedule", passed);
|
||||
|
||||
// 2SM cluster MMA mainloop schedules can be used with direct store ("no-smem") epilogue schedules
|
||||
ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized2SmSm100, cutlass::epilogue::NoSmemWarpSpecialized2Sm> runner_4;
|
||||
passed = runner_4.run(options, hw_info);
|
||||
print_result("KernelTmaWarpSpecialized2SmSm100 mainloop schedule with NoSmemWarpSpecialized2Sm epilogue schedule", passed);
|
||||
|
||||
// 2SM cluster MMA mainloop schedules can also be used with 2SM TMA epilogue schedules
|
||||
// 2SM cluster MMA mainloop schedules will not work with SM TMA epilogue schedules
|
||||
ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized2SmSm100, cutlass::epilogue::TmaWarpSpecialized2Sm> runner_5;
|
||||
passed = runner_5.run(options, hw_info);
|
||||
print_result("KernelTmaWarpSpecialized2SmSm100 mainloop schedule with TmaWarpSpecialized2Sm epilogue schedule", passed);
|
||||
|
||||
// Blackwell Auto schedule supports custom EVT fusions
|
||||
constexpr bool UseCustomEVT = true;
|
||||
ExampleRunner<
|
||||
cutlass::gemm::collective::KernelScheduleAuto,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
UseCustomEVT> runner_6;
|
||||
passed = runner_6.run(options, hw_info);
|
||||
print_result("KernelScheduleAuto mainloop schedule with EpilogueScheduleAuto epilogue schedule and custom EVT", passed);
|
||||
|
||||
// 1SM TMA epilogue schedules support custom EVT fusions
|
||||
ExampleRunner<
|
||||
cutlass::gemm::KernelTmaWarpSpecialized1SmSm100,
|
||||
cutlass::epilogue::TmaWarpSpecialized1Sm,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
UseCustomEVT> runner_7;
|
||||
passed = runner_7.run(options, hw_info);
|
||||
print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with TmaWarpSpecialized1Sm epilogue and custom EVT", passed);
|
||||
|
||||
// 2SM TMA epilogue schedules support custom EVT fusions
|
||||
ExampleRunner<
|
||||
cutlass::gemm::KernelTmaWarpSpecialized2SmSm100,
|
||||
cutlass::epilogue::TmaWarpSpecialized2Sm,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
UseCustomEVT> runner_8;
|
||||
passed = runner_8.run(options, hw_info);
|
||||
print_result("KernelTmaWarpSpecialized2SmSm100 mainloop schedule with TmaWarpSpecialized2Sm epilogue and custom EVT", passed);
|
||||
|
||||
|
||||
// Blackwell direct store epilogue schedule supports custom EVTs and named fusion operations as well (not supported for pre-Blackwell kernels)
|
||||
ExampleRunner<
|
||||
cutlass::gemm::KernelTmaWarpSpecialized1SmSm100,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized1Sm,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
UseCustomEVT> runner_9;
|
||||
passed = runner_9.run(options, hw_info);
|
||||
print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with NoSmemWarpSpecialized1Sm epilogue and custom EVT", passed);
|
||||
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,35 @@
|
||||
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# Both filenames are shorter to avoid MAX_PATH issues on Windows.
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
cutlass_example_add_executable(
|
||||
71_blackwell_gemm_with_collective_builder
|
||||
71_blackwell_gemm_with_collective_builder.cu
|
||||
)
|
||||
endif()
|
||||
@ -0,0 +1,548 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
This example demonstrates a simple way to instantiate and run a blockscaled NVFP4 GEMM on the NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
The Blackwell SM100 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (tcgen05.mma.blockscaled) introduced
|
||||
on the Blackwell architecture (sm100a) which have 2x throughput compared to fp8 Tensor Core MMA instructions (tcgen05.mma)
|
||||
and 4x throughput compared to fp8 Hopper Tensor Core MMA Instructions (WGMMA) (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
Similar to 70_blackwell_gemm, this kernel leverages:
|
||||
1. Per-SM memory called Tensor Memory (TMEM) (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/).
|
||||
|
||||
2. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM
|
||||
which allows us to decouple the execution of MMA and epilogue into separate warps.
|
||||
|
||||
3. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
Usage:
|
||||
|
||||
$ ./examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm --m=2048 --n=2048 --k=2048
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for A matrix operand
|
||||
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for A matrix operand
|
||||
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
|
||||
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
|
||||
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
|
||||
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
|
||||
|
||||
// Kernel Perf config
|
||||
using MmaTileShape = Shape<_256,_256,_256>; // MMA's tile size
|
||||
using ClusterShape = Shape<_4,_4,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
MmaTileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutCTag, AlignmentC,
|
||||
ElementD, LayoutDTag, AlignmentD,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutATag, AlignmentA,
|
||||
ElementB, LayoutBTag, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
|
||||
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
LayoutA layout_A;
|
||||
LayoutSFA layout_SFA;
|
||||
StrideB stride_B;
|
||||
LayoutB layout_B;
|
||||
LayoutSFB layout_SFB;
|
||||
StrideC stride_C;
|
||||
LayoutC layout_C;
|
||||
StrideD stride_D;
|
||||
LayoutD layout_D;
|
||||
uint64_t seed;
|
||||
|
||||
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
|
||||
// Use cute::Tensor and cute::Layout for iterating thru the matrix elements
|
||||
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
|
||||
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
|
||||
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
|
||||
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
|
||||
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
|
||||
// Output Tensor
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
|
||||
// Reference Output Tensor
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
using namespace cute;
|
||||
if constexpr (cute::is_subbyte_v<T>) {
|
||||
return subbyte_iterator<T>(ptr);
|
||||
}
|
||||
else {
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
int swizzle = 0;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(1024), n(1024), k(1024),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10),
|
||||
swizzle(0)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "72a_blackwell_nvfp4_bf16_gemm\n\n"
|
||||
<< " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --swizzle=<int> Cluster rasterization swizzle\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ " << "./examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_block(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if constexpr (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if constexpr (bits_input <= 6) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if constexpr (bits_input <= 8) {
|
||||
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
|
||||
scope_max = 4;
|
||||
scope_min = 1;
|
||||
}
|
||||
else {
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
}
|
||||
}
|
||||
else{
|
||||
scope_max = 4;
|
||||
scope_min = -4;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
using namespace cute;
|
||||
// For SFA and SFB tensors layouts
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
||||
|
||||
layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A);
|
||||
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
|
||||
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
|
||||
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
|
||||
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
|
||||
block_A.reset(cutlass::make_Coord(size(layout_A)));
|
||||
block_B.reset(cutlass::make_Coord(size(layout_B)));
|
||||
block_C.reset(cutlass::make_Coord(size(layout_C)));
|
||||
block_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
|
||||
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
|
||||
|
||||
initialize_block(block_A.host_view(), seed + 2021);
|
||||
initialize_block(block_B.host_view(), seed + 2022);
|
||||
initialize_block(block_C.host_view(), seed + 2023);
|
||||
initialize_block(block_SFA.host_view(), seed + 2024);
|
||||
initialize_block(block_SFB.host_view(), seed + 2025);
|
||||
|
||||
block_A.sync_device();
|
||||
block_B.sync_device();
|
||||
block_C.sync_device();
|
||||
block_SFA.sync_device();
|
||||
block_SFB.sync_device();
|
||||
}
|
||||
|
||||
// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
typename Gemm::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, 1},
|
||||
{ // Mainloop arguments
|
||||
block_A.device_data(), stride_A,
|
||||
block_B.device_data(), stride_B,
|
||||
block_SFA.device_data(), layout_SFA,
|
||||
block_SFB.device_data(), layout_SFB
|
||||
},
|
||||
{ // Epilogue arguments
|
||||
{options.alpha, options.beta},
|
||||
block_C.device_data(), stride_C,
|
||||
block_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
using namespace cute;
|
||||
// Create the arguments for host reference implementation
|
||||
Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A);
|
||||
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
|
||||
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
|
||||
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
decltype(tensor_A), // TensorA
|
||||
decltype(tensor_SFA), // TensorSfA
|
||||
decltype(tensor_B), // TensorB
|
||||
decltype(tensor_SFB) // TensorSfB
|
||||
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
|
||||
|
||||
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
|
||||
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingEpilogueParams<
|
||||
ElementAccumulator, // ElementScalar
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
ElementAccumulator, // ElementCompute
|
||||
decltype(tensor_C), // TensorC
|
||||
decltype(tensor_D) // TensorD
|
||||
> epilogue_params{options.alpha, options.beta, tensor_C, tensor_D};
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// Comparison
|
||||
block_D.sync_host();
|
||||
bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
|
||||
// and must have compute capability at least 100.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,603 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
This example demonstrate a simple way to instantiate and run a blockscaled NVFP4 GEMM on the NVIDIA Blackwell SM100 architecture
|
||||
on NVIDIA Blackwell SM100 architecture. The kernel outputs quantized fp4 values with scale factors that be the input of another GEMM.
|
||||
|
||||
Similar to 72a_blackwell_nvfp4_bf16_gemm, this kernel leverages:
|
||||
1. Blockscaled tcgen05.mma instructions.
|
||||
|
||||
2. Per-SM memory called Tensor Memory (TMEM)
|
||||
|
||||
3. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM
|
||||
which allows us to decouple the execution of MMA and epilogue into separate warps.
|
||||
|
||||
4. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
Usage:
|
||||
|
||||
$ ./examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm --m=2048 --n=2048 --k=2048
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for A matrix operand
|
||||
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for A matrix operand
|
||||
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementD = cutlass::float_e2m1_t; // Element type for D matrix operand
|
||||
using ElementSFD = cutlass::float_ue8m0_t; // Element type for SFB matrix operand
|
||||
using ElementC = float; // Element type for C matrix operand
|
||||
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
|
||||
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
|
||||
using LayoutSFDTag = LayoutDTag; // Layout type for SFD should be same as D matrix operand
|
||||
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
|
||||
|
||||
// Kernel Perf config
|
||||
using MmaTileShape = Shape<_128,_128,_256>; // MMA's tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
constexpr int InputSFVectorSize = 16;
|
||||
constexpr int OutputSFVectorSize = InputSFVectorSize;
|
||||
|
||||
// D = alpha * acc + beta * C
|
||||
// With BlockScaleFactor generation.
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor<
|
||||
OutputSFVectorSize,
|
||||
ElementD,
|
||||
ElementCompute,
|
||||
ElementSFD, LayoutSFDTag,
|
||||
ElementC>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
MmaTileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutCTag, AlignmentC,
|
||||
ElementD, LayoutDTag, AlignmentD,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy
|
||||
FusionOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutATag, AlignmentA,
|
||||
ElementB, LayoutBTag, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
|
||||
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
|
||||
|
||||
using FusionOp = typename Gemm::EpilogueOutputOp;
|
||||
constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported;
|
||||
using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig<OutputSFVectorSize>;
|
||||
using LayoutSFD = typename SfdOutputCfg::LayoutSF;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
LayoutA layout_A;
|
||||
LayoutSFA layout_SFA;
|
||||
StrideB stride_B;
|
||||
LayoutB layout_B;
|
||||
LayoutSFB layout_SFB;
|
||||
StrideC stride_C;
|
||||
LayoutC layout_C;
|
||||
StrideD stride_D;
|
||||
LayoutD layout_D;
|
||||
LayoutSFD layout_SFD;
|
||||
|
||||
uint64_t seed;
|
||||
|
||||
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
|
||||
// Use cute::Tensor and cute::Layout for iterating thru the matrix elements
|
||||
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
|
||||
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
|
||||
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
|
||||
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
|
||||
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
|
||||
// Output Tensors
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
|
||||
cutlass::HostTensor<ElementSFD, cutlass::layout::PackedVectorLayout> block_SFD;
|
||||
// Reference Output Tensors
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
|
||||
cutlass::HostTensor<ElementSFD, cutlass::layout::PackedVectorLayout> block_reference_SFD;
|
||||
// Matrix-wide normalization constant
|
||||
cutlass::HostTensor<ElementCompute, cutlass::layout::PackedVectorLayout> block_Normconst;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
using namespace cute;
|
||||
if constexpr (cute::is_subbyte_v<T>) {
|
||||
return subbyte_iterator<T>(ptr);
|
||||
}
|
||||
else {
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
int swizzle = 0;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(1024), n(1024), k(1024),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10),
|
||||
swizzle(0)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "72b_blackwell_nvfp4_nvfp4_gemm\n\n"
|
||||
<< " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --swizzle=<int> Cluster rasterization swizzle\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ " << "./examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_block(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if constexpr (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if constexpr (bits_input <= 6) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if constexpr (bits_input <= 8) {
|
||||
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
|
||||
scope_max = 4;
|
||||
scope_min = 1;
|
||||
}
|
||||
else {
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
}
|
||||
}
|
||||
else{
|
||||
scope_max = 4;
|
||||
scope_min = -4;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
using namespace cute;
|
||||
// For SFA and SFB tensors layouts
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
// For SFD tensor layout
|
||||
using Sm1xxBlockScaledOutputConfig= typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
||||
|
||||
layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A);
|
||||
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
|
||||
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
|
||||
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
|
||||
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFD = SfdOutputCfg::tile_atom_to_shape_SFD(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
|
||||
block_A.reset(cutlass::make_Coord(size(layout_A)));
|
||||
block_B.reset(cutlass::make_Coord(size(layout_B)));
|
||||
block_C.reset(cutlass::make_Coord(size(layout_C)));
|
||||
block_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_reference_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
|
||||
block_Normconst.reset(cutlass::make_Coord(1));
|
||||
|
||||
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
|
||||
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
|
||||
block_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
|
||||
|
||||
initialize_block(block_A.host_view(), seed + 2021);
|
||||
initialize_block(block_B.host_view(), seed + 2022);
|
||||
initialize_block(block_C.host_view(), seed + 2023);
|
||||
initialize_block(block_SFA.host_view(), seed + 2024);
|
||||
initialize_block(block_SFB.host_view(), seed + 2025);
|
||||
block_Normconst.at(cutlass::make_Coord(0)) = 2;
|
||||
|
||||
block_A.sync_device();
|
||||
block_B.sync_device();
|
||||
block_C.sync_device();
|
||||
block_D.sync_device();
|
||||
block_SFA.sync_device();
|
||||
block_SFB.sync_device();
|
||||
block_SFD.sync_device();
|
||||
block_Normconst.sync_device();
|
||||
|
||||
}
|
||||
|
||||
// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
typename Gemm::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, 1},
|
||||
{ // Mainloop arguments
|
||||
block_A.device_data(), stride_A,
|
||||
block_B.device_data(), stride_B,
|
||||
block_SFA.device_data(), layout_SFA,
|
||||
block_SFB.device_data(), layout_SFB
|
||||
},
|
||||
{ // Epilogue arguments
|
||||
{ options.alpha, options.beta },
|
||||
block_C.device_data(), stride_C,
|
||||
block_D.device_data(), stride_D}
|
||||
};
|
||||
|
||||
if constexpr (IsBlockScaleSupported) {
|
||||
arguments.epilogue.thread.block_scale_factor_ptr = block_SFD.device_data();
|
||||
arguments.epilogue.thread.norm_constant_ptr = block_Normconst.device_data();
|
||||
}
|
||||
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
using namespace cute;
|
||||
// Create the arguments for host reference implementation
|
||||
Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A);
|
||||
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
|
||||
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
|
||||
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
|
||||
|
||||
// think about how to simplify the gemm3x interface.
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
decltype(tensor_A), // TensorA
|
||||
decltype(tensor_SFA), // TensorSfA
|
||||
decltype(tensor_B), // TensorB
|
||||
decltype(tensor_SFB) // TensorSfB
|
||||
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
|
||||
|
||||
Tensor tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
|
||||
Tensor tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
|
||||
Tensor tensor_SFD = make_tensor(block_reference_SFD.host_data(), layout_SFD);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingEpilogueParams<
|
||||
ElementCompute, // ElementScalar
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
ElementCompute, // ElementCompute
|
||||
decltype(tensor_C), // TensorC
|
||||
decltype(tensor_D), // TensorD
|
||||
decltype(tensor_SFD), // TensorSfD
|
||||
cute::Int<OutputSFVectorSize>,
|
||||
cutlass::reference::host::SfStrategy::SfDGen
|
||||
> epilogue_params {options.alpha, options.beta, tensor_C, tensor_D, tensor_SFD, block_Normconst.at(cutlass::make_Coord(0))};
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// Comparison
|
||||
block_D.sync_host();
|
||||
bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
|
||||
|
||||
block_SFD.sync_host();
|
||||
bool passed_sfd = cutlass::reference::host::TensorEquals(block_reference_SFD.host_view(), block_SFD.host_view());
|
||||
passed_sfd &= (cutlass::reference::host::TensorNorm(block_reference_SFD.host_view()) > 0);
|
||||
passed_sfd &= (cutlass::reference::host::TensorNorm(block_SFD.host_view()) > 0);
|
||||
|
||||
return passed && passed_sfd;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
|
||||
// and must have compute capability at least 100.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,549 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
This example demonstrates a simple way to instantiate and run a mixed precision blockscaled GEMM on the NVIDIA Blackwell SM100 architecture.
|
||||
This Blackwell SM100 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (tcgen05.mma.blockscaled) introduced
|
||||
on the Blackwell architecture (sm100a) which have the same throughput compared to fp8 Tensor Core MMA instructions (tcgen05.mma)
|
||||
and 2x throughput compared to fp8 Hopper Tensor Core MMA Instructions (WGMMA) (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
Similar to 72a_blackwell_nvfp4_fp32_gemm, this kernel leverages:
|
||||
1. Blockscaled tcgen05.mma instructions.
|
||||
|
||||
2. Per-SM memory called Tensor Memory (TMEM) (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/).
|
||||
|
||||
3. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM
|
||||
which allows us to decouple the execution of MMA and epilogue into separate warps.
|
||||
|
||||
4. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
Usage:
|
||||
|
||||
$ ./examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm --m=2048 --n=2048 --k=2048
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>; // Element type for A matrix operand
|
||||
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 16; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::mx_float4_t<cutlass::float_e2m1_t>; // Element type for A matrix operand
|
||||
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
|
||||
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
|
||||
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
|
||||
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
|
||||
|
||||
// Kernel Perf config
|
||||
using MmaTileShape = Shape<_256,_256,_256>; // MMA's tile size
|
||||
using ClusterShape = Shape<_4,_4,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
MmaTileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutCTag, AlignmentC,
|
||||
ElementD, LayoutDTag, AlignmentD,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutATag, AlignmentA,
|
||||
ElementB, LayoutBTag, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
|
||||
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
LayoutA layout_A;
|
||||
LayoutSFA layout_SFA;
|
||||
StrideB stride_B;
|
||||
LayoutB layout_B;
|
||||
LayoutSFB layout_SFB;
|
||||
StrideC stride_C;
|
||||
LayoutC layout_C;
|
||||
StrideD stride_D;
|
||||
LayoutD layout_D;
|
||||
uint64_t seed;
|
||||
|
||||
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
|
||||
// Use cute::Tensor and cute::Layout for iterating thru the matrix elements
|
||||
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
|
||||
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
|
||||
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
|
||||
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
|
||||
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
|
||||
// Output Tensor
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
|
||||
// Reference Output Tensor
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
using namespace cute;
|
||||
if constexpr (cute::is_subbyte_v<T>) {
|
||||
return subbyte_iterator<T>(ptr);
|
||||
}
|
||||
else {
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
int swizzle = 0;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(1024), n(1024), k(1024),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10),
|
||||
swizzle(0)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "72c_blackwell_mixed_mxfp8_bf16_gemm\n\n"
|
||||
<< " Blackwell Mxfp8 x Mxfp4 GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --swizzle=<int> Cluster rasterization swizzle\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ " << "/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_block(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if constexpr (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if constexpr (bits_input <= 6) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if constexpr (bits_input <= 8) {
|
||||
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
|
||||
scope_max = 4;
|
||||
scope_min = 1;
|
||||
}
|
||||
else {
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
}
|
||||
}
|
||||
else{
|
||||
scope_max = 4;
|
||||
scope_min = -4;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
using namespace cute;
|
||||
// For SFA and SFB tensors layouts
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
||||
|
||||
layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A);
|
||||
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
|
||||
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
|
||||
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
|
||||
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
|
||||
block_A.reset(cutlass::make_Coord(size(layout_A)));
|
||||
block_B.reset(cutlass::make_Coord(size(layout_B)));
|
||||
block_C.reset(cutlass::make_Coord(size(layout_C)));
|
||||
block_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
|
||||
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
|
||||
|
||||
initialize_block(block_A.host_view(), seed + 2021);
|
||||
initialize_block(block_B.host_view(), seed + 2022);
|
||||
initialize_block(block_C.host_view(), seed + 2023);
|
||||
initialize_block(block_SFA.host_view(), seed + 2024);
|
||||
initialize_block(block_SFB.host_view(), seed + 2025);
|
||||
|
||||
block_A.sync_device();
|
||||
block_B.sync_device();
|
||||
block_C.sync_device();
|
||||
block_SFA.sync_device();
|
||||
block_SFB.sync_device();
|
||||
}
|
||||
|
||||
// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
typename Gemm::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, 1},
|
||||
{ // Mainloop arguments
|
||||
block_A.device_data(), stride_A,
|
||||
block_B.device_data(), stride_B,
|
||||
block_SFA.device_data(), layout_SFA,
|
||||
block_SFB.device_data(), layout_SFB
|
||||
},
|
||||
{ // Epilogue arguments
|
||||
{options.alpha, options.beta},
|
||||
block_C.device_data(), stride_C,
|
||||
block_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
using namespace cute;
|
||||
// Create the arguments for host reference implementation
|
||||
Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A);
|
||||
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
|
||||
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
|
||||
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
decltype(tensor_A), // TensorA
|
||||
decltype(tensor_SFA), // TensorSfA
|
||||
decltype(tensor_B), // TensorB
|
||||
decltype(tensor_SFB) // TensorSfB
|
||||
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
|
||||
|
||||
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
|
||||
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingEpilogueParams<
|
||||
ElementAccumulator, // ElementScalar
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
ElementAccumulator, // ElementCompute
|
||||
decltype(tensor_C), // TensorC
|
||||
decltype(tensor_D) // TensorD
|
||||
> epilogue_params{options.alpha, options.beta, tensor_C, tensor_D};
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// Comparison
|
||||
block_D.sync_host();
|
||||
bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
|
||||
// and must have compute capability at least 100.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
46
examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt
Normal file
46
examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt
Normal file
@ -0,0 +1,46 @@
|
||||
|
||||
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
cutlass_example_add_executable(
|
||||
72a_blackwell_nvfp4_bf16_gemm
|
||||
72a_blackwell_nvfp4_bf16_gemm.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
72b_blackwell_nvfp4_nvfp4_gemm
|
||||
72b_blackwell_nvfp4_nvfp4_gemm.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
72c_blackwell_mixed_mxfp8_bf16_gemm
|
||||
72c_blackwell_mixed_mxfp8_bf16_gemm.cu
|
||||
)
|
||||
endif()
|
||||
36
examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt
Normal file
36
examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt
Normal file
@ -0,0 +1,36 @@
|
||||
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
cutlass_example_add_executable(
|
||||
73_blackwell_gemm_preferred_cluster
|
||||
blackwell_gemm_preferred_cluster.cu
|
||||
)
|
||||
endif()
|
||||
@ -0,0 +1,541 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM100 architecture with preferred cluster.
|
||||
|
||||
With the introduction of NVIDIA Compute Capability 9.0, the CUDA programming model introduced
|
||||
an optional hierarchy level known as Thread Block Clusters, which consist of multiple Thread Blocks.
|
||||
While the CUDA programming model has supported the specification of cluster shapes at runtime
|
||||
(Dynamic Clusters) since the Hopper architecture, CUTLASS has only provided support for Static
|
||||
Clusters, meaning that cluster shapes must be defined at compile time.
|
||||
|
||||
Larger cluster shapes can achieve higher TMA multicast but may result in poor SM occupancy due
|
||||
to quantization. For instance, a 2x2 cluster on an 18 SM GPU would only utilize 16 SMs, leaving
|
||||
2 SMs idle.
|
||||
|
||||
Starting with Compute Capability 10.0, the CUDA programming model adds the ability to specify
|
||||
two clusters: preferred cluster and fallback cluster. For brevity, we refer to this as
|
||||
Preferred Clusters. In the previous example, users can now launch an additional 2x1 cluster to
|
||||
utilize the 2 idle SMs.
|
||||
|
||||
With CUTLASS 3.8, in addition to Dynamic Clusters, CUTLASS adds support for Preferred Dynamic Cluster,
|
||||
the ability for users to specify two clusters shapes at runtime.
|
||||
|
||||
Terminology
|
||||
* Static cluster: cluster shape is specified at compile time.
|
||||
* Dynamic cluster: cluster shape is specified at runtime and set by the host.
|
||||
* Preferred cluster: Kernel can be launched with two cluster shapes (preferred and fallback).
|
||||
|
||||
Preferred and fallback cluster shapes are subject to several constraints.
|
||||
* Preferred cluster depth (Z dimension) must be the same as that of fallback cluster.
|
||||
* Fallback cluster shape must evenly divide the preferred cluster shape.
|
||||
* Preferred cluster shape must evenly divide the kernel launch grid shape.
|
||||
|
||||
This example demonstrates how to use the Dynamic Clusters and Preferred Clusters features in
|
||||
CUTLASS 3.x Blackwell SM100 kernels. Users can specify preferred and fallback cluster shapes via GEMM arguments.
|
||||
|
||||
# Example:
|
||||
./73_blackwell_gemm_preferred_cluster" --m=4096 --n=4096 --k=4096 --preferred_cluster_m=4 --preferred_cluster_n=4 --fallback_cluster_m=2 --fallback_cluster_m=1
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = half_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = half_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = float; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
|
||||
// MMA and Cluster Tile Shapes
|
||||
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape % 2 == 0
|
||||
using MmaTileShape_MNK = Shape<_256,_128,_64>;
|
||||
// Shape of the cluster set to <int,int,_1> to indicate dynamic cluster shape
|
||||
using ClusterShape_MNK = Shape<int,int,_1>;
|
||||
// When dynamic cluster is used, KernelScheduleAuto always selects mainloop dispatch policy that
|
||||
// lowers to tcgen05 MMA cta_group = 1 as we don't know if the dynamic cluster M dimension will be a multiple of 2
|
||||
// To use tcgen05 MMA cta_group = 2, users must explicitly use 2sm builder schedules
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmSm100;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
EpilogueSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void // <--- Default to cluster launch control (CLC) scheduler
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using DeviceGemmReference = cutlass::reference::device::Gemm<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
int preferred_cluster_m, preferred_cluster_n, fallback_cluster_m, fallback_cluster_n;
|
||||
int swizzle = 0;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(4096), n(4096), k(4096),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10),
|
||||
preferred_cluster_m(4),
|
||||
preferred_cluster_n(4),
|
||||
fallback_cluster_m(2),
|
||||
fallback_cluster_n(1),
|
||||
swizzle(0)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("preferred_cluster_m", preferred_cluster_m, 4);
|
||||
cmd.get_cmd_line_argument("preferred_cluster_n", preferred_cluster_n, 4);
|
||||
cmd.get_cmd_line_argument("fallback_cluster_m", fallback_cluster_m, 2);
|
||||
cmd.get_cmd_line_argument("fallback_cluster_n", fallback_cluster_n, 1);
|
||||
cmd.get_cmd_line_argument("swizzle", swizzle);
|
||||
|
||||
if (!validate_cluster_shape()){
|
||||
std::cout << "--Invalid cluster shapes" << std::endl;
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "73_blackwell_gemm_preferred_cluster\n\n"
|
||||
<< " Blackwell FP16 GEMM using preferred cluster.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --swizzle=<int> Cluster rasterization swizzle\n"
|
||||
<< " --preferred_cluster_m=<str> Sets the M extent of preferred cluster shape\n"
|
||||
<< " --preferred_cluster_n=<str> Sets the N extent of preferred cluster shape\n"
|
||||
<< " --fallback_cluster_m=<str> Sets the M extent of fallback cluster shape\n"
|
||||
<< " --fallback_cluster_n=<str> Sets the N extent of fallback cluster shape\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "Preferred cluster shape cannot be smaller than fallback cluster shape.\n"
|
||||
<< "Preferred cluster shape must be a multiple of fallback cluster shape.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ " << "73_blackwell_gemm_preferred_cluster" << " --m=4096 --n=4096 --k=4096 --preferred_cluster_m=4 --preferred_cluster_n=4 --fallback_cluster_m=2 --fallback_cluster_m=1\n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const {
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
|
||||
private:
|
||||
/// Validate preferred and fallback cluster shapes
|
||||
bool validate_cluster_shape() {
|
||||
if (preferred_cluster_m < fallback_cluster_m || preferred_cluster_n < fallback_cluster_n) {
|
||||
std::cout << "--Preferred cluster cannot be smaller than fallback cluster" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (preferred_cluster_m % fallback_cluster_m != 0 || preferred_cluster_n % fallback_cluster_n != 0) {
|
||||
std::cout << "--Preferred cluster must be a multiple of fallback cluster" << std::endl;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(cutlass::DeviceAllocation<Element>& block, uint64_t seed=2023) {
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(0);
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(-2);
|
||||
} else {
|
||||
scope_max = Element(8);
|
||||
scope_min = Element(-8);
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
||||
|
||||
block_A.reset(options.m * options.k);
|
||||
block_B.reset(options.k * options.n);
|
||||
block_C.reset(options.m * options.n);
|
||||
block_D.reset(options.m * options.n);
|
||||
block_ref_D.reset(options.m * options.n);
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options) {
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, 1},
|
||||
{block_A.get(), stride_A, block_B.get(), stride_B},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
|
||||
arguments.hw_info.cluster_shape = dim3(options.preferred_cluster_m, options.preferred_cluster_n,1);
|
||||
arguments.hw_info.cluster_shape_fallback = dim3(options.fallback_cluster_m, options.fallback_cluster_n,1);
|
||||
|
||||
arguments.scheduler.max_swizzle_size = options.swizzle;
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k}));
|
||||
cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.k, options.n}));
|
||||
cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n}));
|
||||
cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n}));
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
DeviceGemmReference gemm_reference;
|
||||
|
||||
// Launch device reference gemm kernel
|
||||
gemm_reference(
|
||||
{options.m, options.n, options.k},
|
||||
ElementAccumulator(options.alpha),
|
||||
ref_A,
|
||||
ref_B,
|
||||
ElementAccumulator(options.beta),
|
||||
ref_C,
|
||||
ref_D);
|
||||
|
||||
// Wait for kernel to finish
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
int run(Options &options) {
|
||||
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << "GEMM with"
|
||||
<< " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k
|
||||
<< " Preferred Cluster = (" << options.preferred_cluster_m << ", " << options.preferred_cluster_n << ", 1)"
|
||||
<< " Fallback Cluster = (" << options.fallback_cluster_m << ", " << options.fallback_cluster_n << ", 1)"
|
||||
<< std::endl;
|
||||
|
||||
std::cout << "--------------------------------------------------------------------------------" << std::endl;
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
|
||||
// and must have compute capability at least 100.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (props.major != 10 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
37
examples/74_blackwell_gemm_streamk/CMakeLists.txt
Normal file
37
examples/74_blackwell_gemm_streamk/CMakeLists.txt
Normal file
@ -0,0 +1,37 @@
|
||||
|
||||
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
cutlass_example_add_executable(
|
||||
74_blackwell_gemm_streamk
|
||||
blackwell_gemm_streamk.cu
|
||||
)
|
||||
endif()
|
||||
587
examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu
Normal file
587
examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu
Normal file
@ -0,0 +1,587 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM100 architecture with the Stream-K scheduler.
|
||||
|
||||
Stream-K is a GEMM parallelization technique that attempts to reduce load imbalance across SMs
|
||||
by parallelizing certain output tiles across the K mode of the GEMM, without using a static splitting factor.
|
||||
For complete details on Stream-K, please see https://arxiv.org/abs/2301.03598.
|
||||
|
||||
CUTLASS's Stream-K scheduler using the CUTLASS 3.x API is capable of supporting various modes of
|
||||
decomposing a GEMM (referred to as "decomposition modes" in this example):
|
||||
* DataParallel: basic GEMM parallelized spatially via tiling, but without splitting the K mode
|
||||
* SplitK: `split_factor` CTAs compute portions of the K mode for a given output tile and reduce their results
|
||||
* StreamK: parallelizes work according to the stream-K load balancing method described in https://arxiv.org/abs/2301.03598
|
||||
* Heuristic: applies an internal heuristic in attempt to choose the most performant among the three preceding decomposition modes
|
||||
|
||||
Additionally, the Stream-K scheduler supports two different means of performing reductions for
|
||||
decomposition modes that require reduction (SplitK, StreamK, and Heuristic):
|
||||
* Deterministic: Participating CTAs perform reduction in a turnstile fashion in order of the K mode
|
||||
covered by each CTA. This requires a lock to be held exclusively by the CTA that is
|
||||
currently accumulating.
|
||||
* Nondeterministic: Participating CTAs perform reduction atomically to the same workspace (mostly) without locking.
|
||||
Locks are used only to wait for the first CTA to write its partial values (to initialize the
|
||||
workspace), and for all but the final CTA to have accumulated (so that the final CTA can load
|
||||
the accumulated value and accumulate it into registers on top of which the epilogue will
|
||||
be performed). Due to the nondeterminsitic ordering of accumulation, deterministic numeric
|
||||
behavior cannot be guaranteed with this mode (e.g., floating-point rounding error will depend
|
||||
on the order of accumulation)
|
||||
|
||||
This example allows one to try out different decomposition modes, reduction modes, and (when using Split-K) splitting factors.
|
||||
Here are a few examples of usage:
|
||||
# Heuristic mode with deterministic reduction
|
||||
./74_blackwell_gemm_streamk" --m=256 --n=256 --k=16384 --decomposition=Heuristic --reduction=Deterministic
|
||||
|
||||
# Stream-K mode with determinsitic reduction
|
||||
./74_blackwell_gemm_streamk" --m=256 --n=256 --k=16384 --decomposition=StreamK --reduction=Deterministic
|
||||
|
||||
# Split-K mode with a splitting factor of 2 and deterministic reduction
|
||||
./74_blackwell_gemm_streamk" --m=256 --n=256 --k=16384 --decomposition=SplitK --reduction=Deterministic --splits=2
|
||||
|
||||
# Stream-K mode with nondeterministic reduction
|
||||
./74_blackwell_gemm_streamk" --m=256 --n=256 --k=16384 --decomposition=StreamK --reduction=Nondeterministic
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = half_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = half_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = float; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
|
||||
// MMA and Cluster Tile Shapes
|
||||
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape % 2 == 0
|
||||
using MmaTileShape_MNK = Shape<_256,_128,_64>;
|
||||
// Shape of the cluster set to <int,int,_1> to indicate dynamic cluster shape
|
||||
using ClusterShape_MNK = Shape<int,int,_1>;
|
||||
// When dynamic cluster is used, KernelScheduleAuto always selects mainloop dispatch policy that
|
||||
// lowers to tcgen05 MMA cta_group = 1 as we don't know if the dynamic cluster M dimension will be a multiple of 2
|
||||
// To use tcgen05 MMA cta_group = 2, users must explicitly use 2sm builder schedules
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmSm100;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
cutlass::gemm::StreamKScheduler // <--- Change needed to enable the stream-K scheduler
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using DeviceGemmReference = cutlass::reference::device::Gemm<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
int preferred_cluster_m, preferred_cluster_n, fallback_cluster_m, fallback_cluster_n;
|
||||
using DecompositionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
|
||||
using ReductionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::ReductionMode;
|
||||
DecompositionMode decomposition_mode;
|
||||
ReductionMode reduction_mode;
|
||||
int splits;
|
||||
|
||||
std::unordered_map<DecompositionMode, std::vector<std::string>> dec_mappings = {
|
||||
{DecompositionMode::Heuristic, {"Heuristic", "heuristic", "h", "H", ""}},
|
||||
{DecompositionMode::SplitK, {"SplitK", "split-k", "split-K", "Split-K", "Split-k", "splitk", "Splitk", "splitK", "spk", "SpK", "spK"}},
|
||||
{DecompositionMode::StreamK, {"StreamK", "stream-k", "stream-K", "Stream-K", "Stream-k", "streamk", "Streamk", "streamK", "stk", "StK", "stK"}},
|
||||
{DecompositionMode::DataParallel, {"DataParallel", "data-parallel", "dataparallel", "dp", "DP"}}
|
||||
};
|
||||
|
||||
std::unordered_map<ReductionMode, std::vector<std::string>> red_mappings = {
|
||||
{ReductionMode::Deterministic, {"Deterministic", "deterministic", "d", "D", ""}},
|
||||
{ReductionMode::Nondeterministic, {"Nondeterministic", "nondeterministic", "n", "N"}}
|
||||
};
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(256), n(256), k(16384),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10),
|
||||
preferred_cluster_m(4),
|
||||
preferred_cluster_n(4),
|
||||
fallback_cluster_m(2),
|
||||
fallback_cluster_n(1),
|
||||
decomposition_mode(DecompositionMode::Heuristic),
|
||||
reduction_mode(ReductionMode::Deterministic),
|
||||
splits(1)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("splits", splits, 1);
|
||||
cmd.get_cmd_line_argument("preferred_cluster_m", preferred_cluster_m, 4);
|
||||
cmd.get_cmd_line_argument("preferred_cluster_n", preferred_cluster_n, 4);
|
||||
cmd.get_cmd_line_argument("fallback_cluster_m", fallback_cluster_m, 2);
|
||||
cmd.get_cmd_line_argument("fallback_cluster_n", fallback_cluster_n, 1);
|
||||
|
||||
// Parse decompsition mode
|
||||
std::string decomp_mode;
|
||||
cmd.get_cmd_line_argument("decomposition", decomp_mode);
|
||||
bool found = parse_from_options_map(decomp_mode, dec_mappings, decomposition_mode);
|
||||
if (!found) {
|
||||
std::cout << "--decomposition must be one of Heuristic, SplitK, StreamK, or DataParallel" << std::endl;
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
// Parse reduction mode
|
||||
std::string red_mode;
|
||||
cmd.get_cmd_line_argument("reduction", red_mode);
|
||||
found = parse_from_options_map(red_mode, red_mappings, reduction_mode);
|
||||
if (!found) {
|
||||
std::cout << "--reduction must be one of Deterministic and Nondeterministic" << std::endl;
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "74_blackwell_gemm_streamk\n\n"
|
||||
<< " Blackwell FP16 GEMM using a stream-K kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --preferred_cluster_m=<str> Sets the M extent of preferred cluster shape\n"
|
||||
<< " --preferred_cluster_n=<str> Sets the N extent of preferred cluster shape\n"
|
||||
<< " --fallback_cluster_m=<str> Sets the M extent of fallback cluster shape\n"
|
||||
<< " --fallback_cluster_n=<str> Sets the N extent of fallback cluster shape\n"
|
||||
<< " --decomposition=<str> Mode in which the stream-K kernel should decompose the problem. Options: Heuristic (default), SplitK, StreamK, DataParallel\n"
|
||||
<< " --reduction=<str> Mode in which the stream-K kernel's reduction should be performed. Options: Deterministic (default), Nondeterministic\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "74_blackwell_gemm_streamk" << " --m=256 --n=256 --k=16384 --decomposition=Heuristic --reduction=Deterministic \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const {
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
|
||||
std::string decomposition_mode_str() const {
|
||||
return dec_mappings.at(decomposition_mode).at(0);
|
||||
}
|
||||
|
||||
std::string reduction_mode_str() const {
|
||||
return red_mappings.at(reduction_mode).at(0);
|
||||
}
|
||||
|
||||
private:
|
||||
template <class T>
|
||||
bool parse_from_options_map(std::string val, std::unordered_map<T, std::vector<std::string>> options, T& result) const {
|
||||
for (const auto & [key, values] : options) {
|
||||
if (std::find(values.begin(), values.end(), val) != values.end()) {
|
||||
result = key;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(cutlass::DeviceAllocation<Element>& block, uint64_t seed=2023) {
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(0);
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(-2);
|
||||
} else {
|
||||
scope_max = Element(8);
|
||||
scope_min = Element(-8);
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
||||
|
||||
block_A.reset(options.m * options.k);
|
||||
block_B.reset(options.k * options.n);
|
||||
block_C.reset(options.m * options.n);
|
||||
block_D.reset(options.m * options.n);
|
||||
block_ref_D.reset(options.m * options.n);
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options) {
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, 1},
|
||||
{block_A.get(), stride_A, block_B.get(), stride_B},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
|
||||
arguments.hw_info.cluster_shape = dim3(options.preferred_cluster_m, options.preferred_cluster_n,1);
|
||||
arguments.hw_info.cluster_shape_fallback = dim3(options.fallback_cluster_m, options.fallback_cluster_n,1);
|
||||
|
||||
arguments.scheduler.splits = options.splits;
|
||||
arguments.scheduler.decomposition_mode = options.decomposition_mode;
|
||||
arguments.scheduler.reduction_mode = options.reduction_mode;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k}));
|
||||
cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.k, options.n}));
|
||||
cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n}));
|
||||
cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n}));
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
DeviceGemmReference gemm_reference;
|
||||
|
||||
// Launch device reference gemm kernel
|
||||
gemm_reference(
|
||||
{options.m, options.n, options.k},
|
||||
ElementAccumulator(options.alpha),
|
||||
ref_A,
|
||||
ref_B,
|
||||
ElementAccumulator(options.beta),
|
||||
ref_C,
|
||||
ref_D);
|
||||
|
||||
// Wait for kernel to finish
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
int run(Options &options) {
|
||||
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << "Stream-K GEMM with"
|
||||
<< " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k
|
||||
<< " Preferred Cluster = (" << options.preferred_cluster_m << ", " << options.preferred_cluster_n << ", 1)"
|
||||
<< " Fallback Cluster = (" << options.fallback_cluster_m << ", " << options.fallback_cluster_n << ", 1)\n"
|
||||
<< " Decomposition_mode=" << options.decomposition_mode_str()
|
||||
<< " Split_count=" << options.splits
|
||||
<< " Reduction_mode=" << options.reduction_mode_str()
|
||||
<< std::endl;
|
||||
|
||||
std::cout << "--------------------------------------------------------------------------------" << std::endl;
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
|
||||
// and must have compute capability at least 100.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
811
examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu
Normal file
811
examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu
Normal file
@ -0,0 +1,811 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
|
||||
/*! \file
|
||||
\brief Grouped GEMM example using CUTLASS 3 APIs for the NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
This example demonstrates an implementation of Grouped GEMM using a TMA + Blackwell SM100 TensorOp-based warp-specialized kernel.
|
||||
For this example all scheduling work is performed on the device.
|
||||
The new feature showcased in this example is device-side modification of TMA descriptors
|
||||
to move between groups/problem_count (represented by groups).
|
||||
https://docs.nvidia.com/cuda/cuda-c-programming-guide/#encoding-a-tensor-map-on-device
|
||||
|
||||
To run this example:
|
||||
|
||||
$ ./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm --m=2048 --n=2048 --k=2048 --groups=10
|
||||
|
||||
The above example command makes all 10 groups to be sized at the given m, n, k sizes.
|
||||
Skipping any of the problem dimensions randomizes it across the different groups.
|
||||
Same applies for alpha and beta values that are randomized across the different groups.
|
||||
|
||||
To run this example for a set of problems using the benchmark option:
|
||||
|
||||
$ ./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm --benchmark=./test_benchmark.txt
|
||||
|
||||
Where the test_benchmark.txt may look as such:
|
||||
0 256x512x128
|
||||
1 256x512x512
|
||||
2 512x256x128
|
||||
3 256x256x128
|
||||
4 256x512x1024
|
||||
5 1024x512x128 and so on
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <float.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "helper.h"
|
||||
using namespace cute;
|
||||
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
|
||||
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// A matrix configuration
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
|
||||
// Runtime Cluster Shape
|
||||
using ClusterShape = Shape<int32_t,int32_t,_1>;
|
||||
|
||||
// Different configs for 1SM and 2SM MMA kernel
|
||||
struct MMA1SMConfig {
|
||||
using MmaTileShape = Shape<_128,_256,Int<128 / sizeof(ElementA)>>;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch
|
||||
};
|
||||
|
||||
struct MMA2SMConfig {
|
||||
using MmaTileShape = Shape<_256,_256,Int<128 / sizeof(ElementA)>>;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch
|
||||
};
|
||||
|
||||
template <typename ScheduleConfig>
|
||||
struct GivenGemmSchedule {
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
typename ScheduleConfig::MmaTileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
typename ScheduleConfig::EpilogueSchedule,
|
||||
cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA *, AlignmentA,
|
||||
ElementB, LayoutB *, AlignmentB,
|
||||
ElementAccumulator,
|
||||
typename ScheduleConfig::MmaTileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
typename ScheduleConfig::KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
};
|
||||
|
||||
using GemmKernel1SM = GivenGemmSchedule<MMA1SMConfig>::GemmKernel;
|
||||
using Gemm1SM = GivenGemmSchedule<MMA1SMConfig>::Gemm;
|
||||
using Gemm = Gemm1SM;
|
||||
|
||||
using GemmKernel2SM = GivenGemmSchedule<MMA2SMConfig>::GemmKernel;
|
||||
using Gemm2SM = GivenGemmSchedule<MMA2SMConfig>::Gemm;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using DeviceGemmReference = cutlass::reference::device::Gemm<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||
|
||||
// Host-side allocations
|
||||
std::vector<int64_t> offset_A;
|
||||
std::vector<int64_t> offset_B;
|
||||
std::vector<int64_t> offset_C;
|
||||
std::vector<int64_t> offset_D;
|
||||
|
||||
std::vector<StrideA> stride_A_host;
|
||||
std::vector<StrideB> stride_B_host;
|
||||
std::vector<StrideC> stride_C_host;
|
||||
std::vector<StrideD> stride_D_host;
|
||||
|
||||
std::vector<ElementAccumulator> alpha_host;
|
||||
std::vector<ElementAccumulator> beta_host;
|
||||
|
||||
// Device-side allocations
|
||||
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
|
||||
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
|
||||
cutlass::DeviceAllocation<const typename Gemm::ElementA *> ptr_A;
|
||||
cutlass::DeviceAllocation<const typename Gemm::ElementB *> ptr_B;
|
||||
cutlass::DeviceAllocation<const typename Gemm::ElementC *> ptr_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_ref_D;
|
||||
|
||||
cutlass::DeviceAllocation<StrideA> stride_A;
|
||||
cutlass::DeviceAllocation<StrideB> stride_B;
|
||||
cutlass::DeviceAllocation<StrideC> stride_C;
|
||||
cutlass::DeviceAllocation<StrideD> stride_D;
|
||||
|
||||
// Note, this is an array of pointers to alpha and beta scaling values per group
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool use_pdl = false;
|
||||
|
||||
float alpha = FLT_MAX;
|
||||
float beta = FLT_MAX;
|
||||
int iterations = 10;
|
||||
int m = 1024, n = 2048, k = 512, groups = 10;
|
||||
dim3 cluster_shape = dim3(4,2,1);
|
||||
dim3 cluster_shape_fallback = dim3(2,1,1);
|
||||
RasterOrderOptions raster_order = RasterOrderOptions::AlongM;
|
||||
int max_sm_count = INT_MAX;
|
||||
std::string benchmark_path;
|
||||
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host;
|
||||
int const tma_alignment_bits = 128;
|
||||
int const alignment = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
if (cmd.check_cmd_line_flag("use_pdl")) {
|
||||
use_pdl = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("groups", groups);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, FLT_MAX);
|
||||
cmd.get_cmd_line_argument("beta", beta, FLT_MAX);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("benchmark", benchmark_path);
|
||||
cmd.get_cmd_line_argument("cluster_m", cluster_shape.x);
|
||||
cmd.get_cmd_line_argument("cluster_n", cluster_shape.y);
|
||||
cmd.get_cmd_line_argument("cluster_fallback_m", cluster_shape_fallback.x);
|
||||
cmd.get_cmd_line_argument("cluster_fallback_n", cluster_shape_fallback.y);
|
||||
cmd.get_cmd_line_argument("max_sm_count", max_sm_count, INT_MAX);
|
||||
|
||||
// Decide how to initialize the problems
|
||||
if (!benchmark_path.empty()) {
|
||||
if (!benchmark_problems()) {
|
||||
problem_sizes_host.clear();
|
||||
return;
|
||||
}
|
||||
}
|
||||
else {
|
||||
randomize_problems(cmd);
|
||||
}
|
||||
|
||||
char raster_char;
|
||||
cmd.get_cmd_line_argument("raster", raster_char);
|
||||
|
||||
if (raster_char == 'N' || raster_char == 'n') {
|
||||
raster_order = RasterOrderOptions::AlongN;
|
||||
}
|
||||
else if (raster_char == 'M' || raster_char == 'm') {
|
||||
raster_order = RasterOrderOptions::AlongM;
|
||||
}
|
||||
}
|
||||
|
||||
void randomize_problems(cutlass::CommandLine &cmd) {
|
||||
int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1;
|
||||
cmd.get_cmd_line_argument("m", cmd_line_m);
|
||||
cmd.get_cmd_line_argument("n", cmd_line_n);
|
||||
cmd.get_cmd_line_argument("k", cmd_line_k);
|
||||
|
||||
problem_sizes_host.reserve(groups);
|
||||
|
||||
for (int i = groups; i > 0; i--) {
|
||||
int m = cmd_line_m;
|
||||
int n = cmd_line_n;
|
||||
int k = cmd_line_k;
|
||||
if (m < 1) {
|
||||
m = alignment * ((rand() % 64) + 1);
|
||||
}
|
||||
if (n < 1) {
|
||||
n = alignment * ((rand() % 64) + 1);
|
||||
}
|
||||
if (k < 1) {
|
||||
k = alignment * ((rand() % 64) + 1);
|
||||
}
|
||||
problem_sizes_host.push_back({m, n, k});
|
||||
}
|
||||
}
|
||||
|
||||
/// Load a benchmark
|
||||
bool benchmark_problems() {
|
||||
std::ifstream file(benchmark_path);
|
||||
if (!file.good()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
while (file.good()) {
|
||||
|
||||
int idx = -1;
|
||||
std::string extent_str;
|
||||
|
||||
file >> idx >> extent_str;
|
||||
|
||||
if (idx < 0 || extent_str.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
cutlass::gemm::GemmCoord extent;
|
||||
std::vector<std::string> tokens;
|
||||
|
||||
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
|
||||
|
||||
for (int i = 0; i < int(tokens.size()); ++i) {
|
||||
int x = std::atoi(tokens.at(i).c_str());
|
||||
|
||||
// round up
|
||||
if (x % alignment) {
|
||||
x += (alignment - (x % alignment));
|
||||
}
|
||||
|
||||
extent.at(i) = x;
|
||||
}
|
||||
|
||||
if (extent.product()) {
|
||||
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
|
||||
}
|
||||
}
|
||||
groups = static_cast<int>(problem_sizes_host.size());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "75_blackwell_grouped_gemm\n\n"
|
||||
<< " Blackwell FP8 Grouped GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM for all groups\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM for all groups\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM for all groups\n"
|
||||
<< " --groups=<int> Sets the number of individual GEMM problems for Grouped GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --cluster_m=<int> and --cluster_n=<int> Sets the X,Y dims of the preferred cluster shape\n"
|
||||
<< " --cluster_fallback_m=<int> and --cluster_fallback_n=<int> Sets the X,Y dims of the fallback cluster shape\n\n"
|
||||
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M)\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform\n\n"
|
||||
<< " --benchmark=<str> Executes a benchmark problem size\n"
|
||||
<< " --max_sm_count=<int> Run kernels using only these number of SMs\n"
|
||||
<< " --use_pdl Launch kernel with PDL (Programmatic Dependent Launch) enabled\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "75_blackwell_grouped_gemm" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s, std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host) const
|
||||
{
|
||||
// Number of real-valued multiply-adds
|
||||
uint64_t fmas = uint64_t();
|
||||
|
||||
for (auto const & problem : problem_sizes_host) {
|
||||
fmas += static_cast<uint64_t>(get<0>(problem)) *
|
||||
static_cast<uint64_t>(get<1>(problem)) *
|
||||
static_cast<uint64_t>(get<2>(problem));
|
||||
}
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * uint64_t(fmas);
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms = 0.0;
|
||||
double gflops = 0.0;
|
||||
cutlass::Status status = cutlass::Status::kSuccess;
|
||||
cudaError_t error = cudaSuccess;
|
||||
bool passed = false;
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023) {
|
||||
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = static_cast<Element>(2);
|
||||
scope_min = static_cast<Element>(0);
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = static_cast<Element>(2);
|
||||
scope_min = static_cast<Element>(-2);
|
||||
} else {
|
||||
scope_max = static_cast<Element>(8);
|
||||
scope_min = static_cast<Element>(-8);
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Allocates device-side data
|
||||
void allocate(const Options &options) {
|
||||
int64_t total_elements_A = 0;
|
||||
int64_t total_elements_B = 0;
|
||||
int64_t total_elements_C = 0;
|
||||
int64_t total_elements_D = 0;
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
auto M = get<0>(problem);
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
|
||||
offset_A.push_back(total_elements_A);
|
||||
offset_B.push_back(total_elements_B);
|
||||
offset_C.push_back(total_elements_C);
|
||||
offset_D.push_back(total_elements_D);
|
||||
|
||||
int64_t elements_A = M * K;
|
||||
int64_t elements_B = K * N;
|
||||
int64_t elements_C = M * N;
|
||||
int64_t elements_D = M * N;
|
||||
|
||||
total_elements_A += elements_A;
|
||||
total_elements_B += elements_B;
|
||||
total_elements_C += elements_C;
|
||||
total_elements_D += elements_D;
|
||||
|
||||
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
|
||||
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
|
||||
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}));
|
||||
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}));
|
||||
|
||||
}
|
||||
|
||||
block_A.reset(total_elements_A);
|
||||
block_B.reset(total_elements_B);
|
||||
block_C.reset(total_elements_C);
|
||||
block_D.reset(total_elements_D);
|
||||
block_ref_D.reset(total_elements_D);
|
||||
block_alpha.reset(options.groups);
|
||||
block_beta.reset(options.groups);
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
|
||||
uint64_t seed = 2020;
|
||||
|
||||
problem_sizes.reset(options.groups);
|
||||
problem_sizes.copy_from_host(options.problem_sizes_host.data());
|
||||
|
||||
//
|
||||
// Assign pointers
|
||||
//
|
||||
|
||||
std::vector<ElementA *> ptr_A_host(options.groups);
|
||||
std::vector<ElementB *> ptr_B_host(options.groups);
|
||||
std::vector<ElementC *> ptr_C_host(options.groups);
|
||||
std::vector<ElementC *> ptr_D_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
|
||||
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
|
||||
ptr_alpha_host.at(i) = block_alpha.get() + i;
|
||||
ptr_beta_host.at(i) = block_beta.get() + i;
|
||||
}
|
||||
|
||||
ptr_A.reset(options.groups);
|
||||
ptr_A.copy_from_host(ptr_A_host.data());
|
||||
|
||||
ptr_B.reset(options.groups);
|
||||
ptr_B.copy_from_host(ptr_B_host.data());
|
||||
|
||||
ptr_C.reset(options.groups);
|
||||
ptr_C.copy_from_host(ptr_C_host.data());
|
||||
|
||||
ptr_D.reset(options.groups);
|
||||
ptr_D.copy_from_host(ptr_D_host.data());
|
||||
|
||||
stride_A.reset(options.groups);
|
||||
stride_A.copy_from_host(stride_A_host.data());
|
||||
|
||||
stride_B.reset(options.groups);
|
||||
stride_B.copy_from_host(stride_B_host.data());
|
||||
|
||||
stride_C.reset(options.groups);
|
||||
stride_C.copy_from_host(stride_C_host.data());
|
||||
|
||||
stride_D.reset(options.groups);
|
||||
stride_D.copy_from_host(stride_D_host.data());
|
||||
|
||||
alpha_device.reset(options.groups);
|
||||
alpha_device.copy_from_host(ptr_alpha_host.data());
|
||||
beta_device.reset(options.groups);
|
||||
beta_device.copy_from_host(ptr_beta_host.data());
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
block_alpha.copy_from_host(alpha_host.data());
|
||||
block_beta.copy_from_host(beta_host.data());
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
template <typename Gemm>
|
||||
typename Gemm::Arguments args_from_options(Options &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = min(cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id), options.max_sm_count);
|
||||
|
||||
if (!is_static_v<ClusterShape>) {
|
||||
if (size<0>(typename Gemm::GemmKernel::CollectiveMainloop::AtomThrShapeMNK{}) == 2 &&
|
||||
(options.cluster_shape.x < 2 || options.cluster_shape_fallback.x < 2)) {
|
||||
std::cout << "Error: MMA2SMConfig kernel config needs cluster_dim.x >= 2" << std::endl;
|
||||
}
|
||||
hw_info.cluster_shape = options.cluster_shape;
|
||||
hw_info.cluster_shape_fallback = options.cluster_shape_fallback;
|
||||
}
|
||||
|
||||
typename Gemm::Arguments arguments;
|
||||
decltype(arguments.epilogue.thread) fusion_args;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
|
||||
// If alpha/beta are provided (via cmd line args) and are scalar, then same alpha/beta applies to all batches.
|
||||
// If pointers to alpha/beta are provided, then alpha/beta can differ between batches/groups.
|
||||
if (options.alpha != FLT_MAX){
|
||||
// Single alpha for all groups
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.alpha_ptr_array = nullptr;
|
||||
fusion_args.dAlpha = {_0{}, _0{}, 0};
|
||||
}
|
||||
else {
|
||||
fusion_args.alpha = 0;
|
||||
fusion_args.alpha_ptr_array = alpha_device.get();
|
||||
// Only one alpha per each group
|
||||
fusion_args.dAlpha = {_0{}, _0{}, 1};
|
||||
}
|
||||
if (options.beta != FLT_MAX) {
|
||||
// Single beta for all groups
|
||||
fusion_args.beta = options.beta;
|
||||
fusion_args.beta_ptr_array = nullptr;
|
||||
fusion_args.dBeta = {_0{}, _0{}, 0};
|
||||
}
|
||||
else {
|
||||
fusion_args.beta = 0;
|
||||
fusion_args.beta_ptr_array = beta_device.get();
|
||||
// Only one beta per each group
|
||||
fusion_args.dBeta = {_0{}, _0{}, 1};
|
||||
}
|
||||
|
||||
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
|
||||
scheduler.raster_order = options.raster_order;
|
||||
|
||||
if (host_problem_shapes_available) {
|
||||
arguments = typename Gemm::Arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
|
||||
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info, scheduler
|
||||
};
|
||||
}
|
||||
else {
|
||||
arguments = typename Gemm::Arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), nullptr},
|
||||
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info, scheduler
|
||||
};
|
||||
}
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
bool passed = true;
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
auto M = get<0>(problem);
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
cutlass::TensorRef ref_A(block_A.get() + offset_A.at(i), Gemm::LayoutA::packed({M, K}));
|
||||
cutlass::TensorRef ref_B(block_B.get() + offset_B.at(i), Gemm::LayoutB::packed({K, N}));
|
||||
cutlass::TensorRef ref_C(block_C.get() + offset_C.at(i), Gemm::LayoutC::packed({M, N}));
|
||||
cutlass::TensorRef ref_D(block_ref_D.get() + offset_D.at(i), Gemm::LayoutD::packed({M, N}));
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
DeviceGemmReference gemm_reference;
|
||||
|
||||
// Launch device reference gemm kernel
|
||||
gemm_reference(
|
||||
{M, N, K},
|
||||
ElementAccumulator(alpha_host.at(i)),
|
||||
ref_A,
|
||||
ref_B,
|
||||
ElementAccumulator(beta_host.at(i)),
|
||||
ref_C,
|
||||
ref_D);
|
||||
|
||||
// Wait for kernel to finish
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
passed &= cutlass::reference::device::BlockCompareEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N);
|
||||
}
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
std::cout << " " << options.problem_sizes_host.at(i);
|
||||
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
|
||||
}
|
||||
std::cout << " Groups : " << options.groups << std::endl;
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options<Gemm>(options, host_problem_shapes_available);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl));
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl));
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average setup and runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host);
|
||||
|
||||
std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS : " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
|
||||
if (__CUDACC_VER_MAJOR__ < 12 ||
|
||||
((__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)
|
||||
)
|
||||
) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 100a).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
std::cout << "Running kernel with 1SM MMA config:" << std::endl;
|
||||
run<Gemm1SM>(options, false /*host_problem_shapes_available*/);
|
||||
std::cout << "Running kernel with 2SM MMA config:" << std::endl;
|
||||
run<Gemm2SM>(options, false /*host_problem_shapes_available*/);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,951 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
|
||||
/*! \file
|
||||
\brief Grouped GEMM example using CUTLASS 3 APIs for the NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
This example demonstrates an implementation of Grouped GEMM using a TMA + Blackwell SM100 TensorOp-based warp-specialized kernel
|
||||
for narrow precisions (FP4) with Scale Factors (In and Out).
|
||||
For this example all scheduling work is performed on the device.
|
||||
The new feature showcased in this example is device-side modification of TMA descriptors
|
||||
to move between groups/problem_count (represented by groups).
|
||||
https://docs.nvidia.com/cuda/cuda-c-programming-guide/#encoding-a-tensor-map-on-device
|
||||
|
||||
To run this example:
|
||||
|
||||
$ ./examples/75_blackwell_grouped_gemm_block_scaled/75_blackwell_grouped_gemm_block_scaled --m=2048 --n=2048 --k=2048 --groups=10
|
||||
|
||||
The above example command makes all 10 groups to be sized at the given m, n, k sizes.
|
||||
Skipping any of the problem dimensions randomizes it across the different groups.
|
||||
Same applies for alpha and beta values that are randomized across the different groups.
|
||||
|
||||
To run this example for a set of problems using the benchmark option:
|
||||
|
||||
$ ./examples/75_blackwell_grouped_gemm_block_scaled/75_blackwell_grouped_gemm_block_scaled --benchmark=./test_benchmark.txt
|
||||
|
||||
Where the test_benchmark.txt may look as such:
|
||||
0 256x512x128
|
||||
1 256x512x512
|
||||
2 512x256x128
|
||||
3 256x256x128
|
||||
4 256x512x1024
|
||||
5 1024x512x128 and so on
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <float.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
|
||||
#include "helper.h"
|
||||
using namespace cute;
|
||||
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
|
||||
using ElementInput = cutlass::float_e2m1_t; // Element type for Input matrix operands
|
||||
using ElementSF = cutlass::float_ue4m3_t; // Element type for SF matrix operands
|
||||
using ElementC = cutlass::half_t; // Element type for C matrix operands
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::nv_float4_t<ElementInput>; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 32; // Alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::nv_float4_t<ElementInput>; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 32; // Alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementD = ElementC; // Element type for D matrix operands
|
||||
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Alignment of C matrix in units of elements (up to 16 bytes)
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Alignment of D matrix in units of elements (up to 16 bytes)
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
|
||||
// using ElementD = cutlass::float_e2m1_t; // Enable for SF Output // Element type for D matrix operands
|
||||
using ElementSFD = cutlass::float_ue4m3_t; // Element type for SF Output operands
|
||||
constexpr int OutputSFVectorSize = 16;
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinCombEltActBlockScaleFactor<
|
||||
cutlass::epilogue::thread::SiLu,
|
||||
OutputSFVectorSize,
|
||||
ElementD,
|
||||
ElementAccumulator,
|
||||
ElementSFD,
|
||||
LayoutC,
|
||||
ElementC>;
|
||||
|
||||
// Core kernel configurations
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using EpilogueOperatorClass = cutlass::arch::OpClassTensorOp; // Epilogue Operator class tag
|
||||
using MainloopOperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Mainloop Operator class tag
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
|
||||
// Runtime Cluster Shape
|
||||
using ClusterShape = Shape<int32_t,int32_t,_1>;
|
||||
|
||||
// Different configs for 1SM and 2SM MMA kernel
|
||||
struct MMA1SMConfig {
|
||||
using MmaTileShape = Shape<_128,_256,_256>;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100; // Kernel to launch
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch
|
||||
};
|
||||
|
||||
struct MMA2SMConfig {
|
||||
using MmaTileShape = Shape<_256,_256,_256>;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmNvf4Sm100; // Kernel to launch
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch
|
||||
};
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, EpilogueOperatorClass,
|
||||
typename MMA1SMConfig::MmaTileShape, ClusterShape,
|
||||
Shape<_128,_64>,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
ElementD, LayoutC *, AlignmentD,
|
||||
typename MMA1SMConfig::EpilogueSchedule
|
||||
// , FusionOperation // Enable for SF Output
|
||||
>::CollectiveOp;
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, MainloopOperatorClass,
|
||||
ElementA, LayoutA *, AlignmentA,
|
||||
ElementB, LayoutB *, AlignmentB,
|
||||
ElementAccumulator,
|
||||
typename MMA1SMConfig::MmaTileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
typename MMA1SMConfig::KernelSchedule
|
||||
>::CollectiveOp;
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
using Gemm = Gemm1SM;
|
||||
|
||||
using CollectiveEpilogue2SM = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, EpilogueOperatorClass,
|
||||
typename MMA2SMConfig::MmaTileShape, ClusterShape,
|
||||
Shape<_128,_64>,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
ElementD, LayoutC *, AlignmentD,
|
||||
typename MMA2SMConfig::EpilogueSchedule
|
||||
// , FusionOperation // Enable for SF Output
|
||||
>::CollectiveOp;
|
||||
using CollectiveMainloop2SM = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, MainloopOperatorClass,
|
||||
ElementA, LayoutA *, AlignmentA,
|
||||
ElementB, LayoutB *, AlignmentB,
|
||||
ElementAccumulator,
|
||||
typename MMA2SMConfig::MmaTileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
typename MMA2SMConfig::KernelSchedule
|
||||
>::CollectiveOp;
|
||||
using GemmKernel2SM = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop2SM,
|
||||
CollectiveEpilogue2SM
|
||||
>;
|
||||
using Gemm2SM = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel2SM>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||
|
||||
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig<
|
||||
OutputSFVectorSize,
|
||||
cute::is_same_v<typename FusionOperation::GmemLayoutTagScalefactor,
|
||||
cutlass::layout::RowMajor> ? cute::UMMA::Major::K : cute::UMMA::Major::MN
|
||||
>;
|
||||
using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom;
|
||||
using LayoutSFD = typename Sm1xxBlockScaledOutputConfig::LayoutSF;
|
||||
|
||||
// Host-side allocations
|
||||
std::vector<StrideA> stride_A_host;
|
||||
std::vector<StrideB> stride_B_host;
|
||||
std::vector<LayoutSFA> layout_SFA_host;
|
||||
std::vector<LayoutSFA> layout_SFB_host;
|
||||
std::vector<StrideC> stride_C_host;
|
||||
std::vector<StrideD> stride_D_host;
|
||||
|
||||
std::vector<ElementAccumulator> alpha_host;
|
||||
std::vector<ElementAccumulator> beta_host;
|
||||
|
||||
using HostTensorA = cutlass::HostTensor<typename Gemm::ElementA, cutlass::layout::PackedVectorLayout>;
|
||||
using HostTensorB = cutlass::HostTensor<typename Gemm::ElementB, cutlass::layout::PackedVectorLayout>;
|
||||
using HostTensorSF = cutlass::HostTensor<typename Gemm::GemmKernel::ElementSF, cutlass::layout::PackedVectorLayout>;
|
||||
using HostTensorC = cutlass::HostTensor<typename Gemm::ElementC, cutlass::layout::PackedVectorLayout>;
|
||||
using HostTensorD = cutlass::HostTensor<typename Gemm::EpilogueOutputOp::ElementOutput, cutlass::layout::PackedVectorLayout>;
|
||||
std::vector<HostTensorA> block_A;
|
||||
std::vector<HostTensorB> block_B;
|
||||
std::vector<HostTensorSF> block_SFA;
|
||||
std::vector<HostTensorSF> block_SFB;
|
||||
std::vector<HostTensorC> block_C;
|
||||
std::vector<HostTensorD> block_D;
|
||||
std::vector<HostTensorSF> block_SFD;
|
||||
std::vector<HostTensorD> block_ref_D;
|
||||
|
||||
// Device-side allocations
|
||||
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
|
||||
|
||||
cutlass::DeviceAllocation<const typename Gemm::ElementA *> ptr_A;
|
||||
cutlass::DeviceAllocation<const typename Gemm::ElementB *> ptr_B;
|
||||
cutlass::DeviceAllocation<const typename Gemm::GemmKernel::ElementSF *> ptr_SFA;
|
||||
cutlass::DeviceAllocation<const typename Gemm::GemmKernel::ElementSF *> ptr_SFB;
|
||||
cutlass::DeviceAllocation<const typename Gemm::ElementC *> ptr_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::GemmKernel::ElementSF *> ptr_SFD;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_ref_D;
|
||||
|
||||
cutlass::DeviceAllocation<StrideA> stride_A;
|
||||
cutlass::DeviceAllocation<StrideB> stride_B;
|
||||
cutlass::DeviceAllocation<LayoutSFA> layout_SFA;
|
||||
cutlass::DeviceAllocation<LayoutSFB> layout_SFB;
|
||||
cutlass::DeviceAllocation<StrideC> stride_C;
|
||||
cutlass::DeviceAllocation<StrideD> stride_D;
|
||||
|
||||
// Note, this is an array of pointers to alpha and beta scaling values per group
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
|
||||
// A matrix wide constant value to scale the output matrix
|
||||
// Avoids generating small FP4 values.
|
||||
// NormConst is a single device-side constant value, its not per-batch or per-group
|
||||
cutlass::DeviceAllocation<ElementAccumulator> norm_constant_device;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
using namespace cute;
|
||||
if constexpr (cute::is_subbyte_v<T>) {
|
||||
return subbyte_iterator<T>(ptr);
|
||||
}
|
||||
else {
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool verification = true;
|
||||
bool use_pdl = false;
|
||||
|
||||
float alpha = FLT_MAX;
|
||||
float beta = FLT_MAX;
|
||||
float norm_constant = 1.0;
|
||||
int iterations = 10;
|
||||
int m = 1024, n = 2048, k = 512, groups = 10;
|
||||
dim3 cluster_shape = dim3(2,1,1);
|
||||
dim3 cluster_shape_fallback = dim3(2,1,1);
|
||||
RasterOrderOptions raster_order = RasterOrderOptions::AlongN;
|
||||
int max_sm_count = INT_MAX;
|
||||
std::string benchmark_path;
|
||||
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host;
|
||||
int const tma_alignment_bits = 128;
|
||||
int const alignment = tma_alignment_bits / cutlass::sizeof_bits<ElementInput>::value;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
if (cmd.check_cmd_line_flag("no_verif")) {
|
||||
verification = false;
|
||||
}
|
||||
if (cmd.check_cmd_line_flag("use_pdl")) {
|
||||
use_pdl = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("groups", groups);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, FLT_MAX);
|
||||
cmd.get_cmd_line_argument("beta", beta, FLT_MAX);
|
||||
cmd.get_cmd_line_argument("norm_constant", norm_constant, float(1.0));
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("benchmark", benchmark_path);
|
||||
cmd.get_cmd_line_argument("cluster_m", cluster_shape.x);
|
||||
cmd.get_cmd_line_argument("cluster_n", cluster_shape.y);
|
||||
cmd.get_cmd_line_argument("cluster_fallback_m", cluster_shape_fallback.x);
|
||||
cmd.get_cmd_line_argument("cluster_fallback_n", cluster_shape_fallback.y);
|
||||
cmd.get_cmd_line_argument("max_sm_count", max_sm_count, INT_MAX);
|
||||
|
||||
// Decide how to initialize the problems
|
||||
if (!benchmark_path.empty()) {
|
||||
if (!benchmark_problems()) {
|
||||
problem_sizes_host.clear();
|
||||
return;
|
||||
}
|
||||
}
|
||||
else {
|
||||
randomize_problems(cmd);
|
||||
}
|
||||
|
||||
char raster_char;
|
||||
cmd.get_cmd_line_argument("raster", raster_char);
|
||||
|
||||
if (raster_char == 'N' || raster_char == 'n') {
|
||||
raster_order = RasterOrderOptions::AlongN;
|
||||
}
|
||||
else if (raster_char == 'M' || raster_char == 'm') {
|
||||
raster_order = RasterOrderOptions::AlongM;
|
||||
}
|
||||
}
|
||||
|
||||
void randomize_problems(cutlass::CommandLine &cmd) {
|
||||
int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1;
|
||||
cmd.get_cmd_line_argument("m", cmd_line_m);
|
||||
cmd.get_cmd_line_argument("n", cmd_line_n);
|
||||
cmd.get_cmd_line_argument("k", cmd_line_k);
|
||||
|
||||
problem_sizes_host.reserve(groups);
|
||||
|
||||
for (int i = groups; i > 0; i--) {
|
||||
int m = cmd_line_m;
|
||||
int n = cmd_line_n;
|
||||
int k = cmd_line_k;
|
||||
if (m < 1) {
|
||||
m = alignment * ((rand() % 64) + 1);
|
||||
}
|
||||
if (n < 1) {
|
||||
n = alignment * ((rand() % 64) + 1);
|
||||
}
|
||||
if (k < 1) {
|
||||
k = alignment * ((rand() % 64) + 1);
|
||||
}
|
||||
problem_sizes_host.push_back({m, n, k});
|
||||
}
|
||||
}
|
||||
|
||||
/// Load a benchmark
|
||||
bool benchmark_problems() {
|
||||
std::ifstream file(benchmark_path);
|
||||
if (!file.good()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
while (file.good()) {
|
||||
|
||||
int idx = -1;
|
||||
std::string extent_str;
|
||||
|
||||
file >> idx >> extent_str;
|
||||
|
||||
if (idx < 0 || extent_str.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
cutlass::gemm::GemmCoord extent;
|
||||
std::vector<std::string> tokens;
|
||||
|
||||
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
|
||||
|
||||
for (int i = 0; i < int(tokens.size()); ++i) {
|
||||
int x = std::atoi(tokens.at(i).c_str());
|
||||
|
||||
// round up
|
||||
if (x % alignment) {
|
||||
x += (alignment - (x % alignment));
|
||||
}
|
||||
|
||||
extent.at(i) = x;
|
||||
}
|
||||
|
||||
if (extent.product()) {
|
||||
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
|
||||
}
|
||||
}
|
||||
groups = static_cast<int>(problem_sizes_host.size());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "75_blackwell_grouped_gemm_block_scaled\n\n"
|
||||
<< " Blackwell Block Scaled Narrow Precision Grouped GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM for all groups\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM for all groups\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM for all groups\n"
|
||||
<< " --groups=<int> Sets the number of individual GEMM problems for Grouped GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --norm_constant=<f32> Epilogue scalar normalization constant for the output matrix\n\n"
|
||||
<< " --cluster_m=<int> and --cluster_n=<int> Sets the X,Y dims of the preferred cluster shape\n"
|
||||
<< " --cluster_fallback_m=<int> and --cluster_fallback_n=<int> Sets the X,Y dims of the fallback cluster shape\n\n"
|
||||
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M)\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform\n\n"
|
||||
<< " --benchmark=<str> Executes a benchmark problem size\n"
|
||||
<< " --max_sm_count=<int> Run kernels using only these number of SMs\n"
|
||||
<< " --no_verif Do not run (host-side) verification kernels\n"
|
||||
<< " --use_pdl Launch kernel with PDL (Programmatic Dependent Launch) enabled\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "75_blackwell_grouped_gemm_block_scaled" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s, std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host) const
|
||||
{
|
||||
// Number of real-valued multiply-adds
|
||||
uint64_t fmas = uint64_t();
|
||||
|
||||
for (auto const & problem : problem_sizes_host) {
|
||||
fmas += static_cast<uint64_t>(get<0>(problem)) *
|
||||
static_cast<uint64_t>(get<1>(problem)) *
|
||||
static_cast<uint64_t>(get<2>(problem));
|
||||
}
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * uint64_t(fmas);
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms = 0.0;
|
||||
double gflops = 0.0;
|
||||
cutlass::Status status = cutlass::Status::kSuccess;
|
||||
cudaError_t error = cudaSuccess;
|
||||
bool passed = false;
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_block(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if constexpr (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if constexpr (bits_input <= 6) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if constexpr (bits_input <= 8) {
|
||||
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
|
||||
scope_max = 4;
|
||||
scope_min = 1;
|
||||
}
|
||||
else {
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
}
|
||||
}
|
||||
else{
|
||||
scope_max = 4;
|
||||
scope_min = -4;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Allocates device-side data
|
||||
void allocate(const Options &options) {
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
auto M = get<0>(problem);
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
|
||||
auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1});
|
||||
auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1});
|
||||
auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1});
|
||||
auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1});
|
||||
|
||||
auto layout_A = make_layout(make_shape(M, K, 1), stride_A);
|
||||
auto layout_B = make_layout(make_shape(N, K, 1), stride_B);
|
||||
auto layout_C = make_layout(make_shape(M, N, 1), stride_C);
|
||||
auto layout_D = make_layout(make_shape(M, N, 1), stride_D);
|
||||
auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFD = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1));
|
||||
|
||||
stride_A_host.push_back(stride_A);
|
||||
stride_B_host.push_back(stride_B);
|
||||
layout_SFA_host.push_back(layout_SFA);
|
||||
layout_SFB_host.push_back(layout_SFB);
|
||||
stride_C_host.push_back(stride_C);
|
||||
stride_D_host.push_back(stride_D);
|
||||
|
||||
block_A.push_back(HostTensorA(cutlass::make_Coord(size(layout_A))));
|
||||
block_B.push_back(HostTensorB(cutlass::make_Coord(size(layout_B))));
|
||||
block_SFA.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFA)))));
|
||||
block_SFB.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFB)))));
|
||||
block_C.push_back(HostTensorC(cutlass::make_Coord(size(layout_C))));
|
||||
block_D.push_back(HostTensorD(cutlass::make_Coord(size(layout_D))));
|
||||
block_SFD.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFD)))));
|
||||
block_ref_D.push_back(HostTensorD(cutlass::make_Coord(size(layout_D))));
|
||||
}
|
||||
block_alpha.reset(options.groups);
|
||||
block_beta.reset(options.groups);
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
uint64_t seed = 2020;
|
||||
problem_sizes.reset(options.groups);
|
||||
problem_sizes.copy_from_host(options.problem_sizes_host.data());
|
||||
|
||||
//
|
||||
// Assign pointers
|
||||
//
|
||||
|
||||
std::vector<typename Gemm::ElementA *> ptr_A_host(options.groups);
|
||||
std::vector<typename Gemm::ElementB *> ptr_B_host(options.groups);
|
||||
std::vector<typename Gemm::GemmKernel::ElementSF *> ptr_SFA_host(options.groups);
|
||||
std::vector<typename Gemm::GemmKernel::ElementSF *> ptr_SFB_host(options.groups);
|
||||
std::vector<typename Gemm::ElementC *> ptr_C_host(options.groups);
|
||||
std::vector<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_D_host(options.groups);
|
||||
std::vector<typename Gemm::GemmKernel::ElementSF *> ptr_SFD_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
|
||||
initialize_block(block_A.at(i).host_view(), seed + 2021);
|
||||
initialize_block(block_B.at(i).host_view(), seed + 2022);
|
||||
initialize_block(block_C.at(i).host_view(), seed + 2023);
|
||||
initialize_block(block_SFA.at(i).host_view(), seed + 2024);
|
||||
initialize_block(block_SFB.at(i).host_view(), seed + 2025);
|
||||
|
||||
block_A.at(i).sync_device();
|
||||
block_B.at(i).sync_device();
|
||||
block_C.at(i).sync_device();
|
||||
block_SFA.at(i).sync_device();
|
||||
block_SFB.at(i).sync_device();
|
||||
|
||||
ptr_A_host.at(i) = block_A.at(i).device_data();
|
||||
ptr_B_host.at(i) = block_B.at(i).device_data();
|
||||
ptr_SFA_host.at(i) = block_SFA.at(i).device_data();
|
||||
ptr_SFB_host.at(i) = block_SFB.at(i).device_data();
|
||||
ptr_C_host.at(i) = block_C.at(i).device_data();
|
||||
ptr_D_host.at(i) = block_D.at(i).device_data();
|
||||
ptr_SFD_host.at(i) = block_SFD.at(i).device_data();
|
||||
|
||||
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
|
||||
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
|
||||
ptr_alpha_host.at(i) = block_alpha.get() + i;
|
||||
ptr_beta_host.at(i) = block_beta.get() + i;
|
||||
}
|
||||
|
||||
ptr_A.reset(options.groups);
|
||||
ptr_A.copy_from_host(ptr_A_host.data());
|
||||
|
||||
ptr_B.reset(options.groups);
|
||||
ptr_B.copy_from_host(ptr_B_host.data());
|
||||
|
||||
ptr_SFA.reset(options.groups);
|
||||
ptr_SFA.copy_from_host(ptr_SFA_host.data());
|
||||
|
||||
ptr_SFB.reset(options.groups);
|
||||
ptr_SFB.copy_from_host(ptr_SFB_host.data());
|
||||
|
||||
ptr_C.reset(options.groups);
|
||||
ptr_C.copy_from_host(ptr_C_host.data());
|
||||
|
||||
ptr_D.reset(options.groups);
|
||||
ptr_D.copy_from_host(ptr_D_host.data());
|
||||
|
||||
ptr_SFD.reset(options.groups);
|
||||
ptr_SFD.copy_from_host(ptr_SFD_host.data());
|
||||
|
||||
stride_A.reset(options.groups);
|
||||
stride_A.copy_from_host(stride_A_host.data());
|
||||
|
||||
stride_B.reset(options.groups);
|
||||
stride_B.copy_from_host(stride_B_host.data());
|
||||
|
||||
layout_SFA.reset(options.groups);
|
||||
layout_SFA.copy_from_host(layout_SFA_host.data());
|
||||
|
||||
layout_SFB.reset(options.groups);
|
||||
layout_SFB.copy_from_host(layout_SFB_host.data());
|
||||
|
||||
stride_C.reset(options.groups);
|
||||
stride_C.copy_from_host(stride_C_host.data());
|
||||
|
||||
stride_D.reset(options.groups);
|
||||
stride_D.copy_from_host(stride_D_host.data());
|
||||
|
||||
alpha_device.reset(options.groups);
|
||||
alpha_device.copy_from_host(ptr_alpha_host.data());
|
||||
beta_device.reset(options.groups);
|
||||
beta_device.copy_from_host(ptr_beta_host.data());
|
||||
|
||||
block_alpha.copy_from_host(alpha_host.data());
|
||||
block_beta.copy_from_host(beta_host.data());
|
||||
|
||||
norm_constant_device.reset(1);
|
||||
norm_constant_device.copy_from_host(&options.norm_constant);
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
template <typename Gemm>
|
||||
typename Gemm::Arguments args_from_options(Options &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = min(cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id), options.max_sm_count);
|
||||
|
||||
if (!is_static_v<ClusterShape>) {
|
||||
if (size<0>(typename Gemm::GemmKernel::CollectiveMainloop::AtomThrShapeMNK{}) == 2 &&
|
||||
(options.cluster_shape.x < 2 || options.cluster_shape_fallback.x < 2)) {
|
||||
std::cout << "Error: MMA2SMConfig kernel config needs cluster_dim.x >= 2" << std::endl;
|
||||
}
|
||||
hw_info.cluster_shape = options.cluster_shape;
|
||||
hw_info.cluster_shape_fallback = options.cluster_shape_fallback;
|
||||
}
|
||||
|
||||
typename Gemm::Arguments arguments;
|
||||
decltype(arguments.epilogue.thread) fusion_args;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
|
||||
// If alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
|
||||
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
|
||||
if (options.alpha != FLT_MAX){
|
||||
// Single alpha for all groups
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.alpha_ptr_array = nullptr;
|
||||
fusion_args.dAlpha = {_0{}, _0{}, 0};
|
||||
}
|
||||
else {
|
||||
fusion_args.alpha = 0;
|
||||
fusion_args.alpha_ptr_array = alpha_device.get();
|
||||
// Only one alpha per each group
|
||||
fusion_args.dAlpha = {_0{}, _0{}, 1};
|
||||
}
|
||||
if (options.beta != FLT_MAX) {
|
||||
// Single beta for all groups
|
||||
fusion_args.beta = options.beta;
|
||||
fusion_args.beta_ptr_array = nullptr;
|
||||
fusion_args.dBeta = {_0{}, _0{}, 0};
|
||||
}
|
||||
else {
|
||||
fusion_args.beta = 0;
|
||||
fusion_args.beta_ptr_array = beta_device.get();
|
||||
// Only one beta per each group
|
||||
fusion_args.dBeta = {_0{}, _0{}, 1};
|
||||
}
|
||||
// Output Block SF
|
||||
// fusion_args.block_scale_factor_ptr = ptr_SFD.get(); // Enable for SF Output
|
||||
// fusion_args.norm_constant_ptr = norm_constant_device.get(); // Enable for SF Output
|
||||
|
||||
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
|
||||
scheduler.raster_order = options.raster_order;
|
||||
|
||||
if (host_problem_shapes_available) {
|
||||
arguments = typename Gemm::Arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
|
||||
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(),
|
||||
ptr_SFA.get(), layout_SFA.get(), ptr_SFB.get(), layout_SFB.get()},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info, scheduler
|
||||
};
|
||||
}
|
||||
else {
|
||||
arguments = typename Gemm::Arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), nullptr},
|
||||
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(),
|
||||
ptr_SFA.get(), layout_SFA.get(), ptr_SFB.get(), layout_SFB.get()},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info, scheduler
|
||||
};
|
||||
}
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
using namespace cute;
|
||||
bool passed = true;
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
auto M = get<0>(problem);
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
|
||||
auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1});
|
||||
auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1});
|
||||
auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1});
|
||||
auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1});
|
||||
auto layout_A = make_layout(make_shape(M, K, 1), stride_A);
|
||||
auto layout_B = make_layout(make_shape(N, K, 1), stride_B);
|
||||
auto layout_C = make_layout(make_shape(M, N, 1), stride_C);
|
||||
auto layout_D = make_layout(make_shape(M, N, 1), stride_D);
|
||||
auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFD = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1));
|
||||
|
||||
// Create the arguments for host reference implementation
|
||||
Tensor tensor_A = make_tensor(make_iterator(block_A.at(i).host_data()), layout_A);
|
||||
Tensor tensor_SFA = make_tensor(block_SFA.at(i).host_data(), layout_SFA);
|
||||
Tensor tensor_B = make_tensor(make_iterator(block_B.at(i).host_data()), layout_B);
|
||||
Tensor tensor_SFB = make_tensor(block_SFB.at(i).host_data(), layout_SFB);
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<ElementAccumulator,
|
||||
decltype(tensor_A),
|
||||
decltype(tensor_SFA),
|
||||
decltype(tensor_B),
|
||||
decltype(tensor_SFB)
|
||||
>
|
||||
mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
|
||||
|
||||
auto tensor_C = cute::make_tensor(make_iterator(block_C.at(i).host_data()), layout_C);
|
||||
auto tensor_ref_D = cute::make_tensor(make_iterator(block_ref_D.at(i).host_data()), layout_D);
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
float, float,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
decltype(tensor_C), decltype(tensor_ref_D)
|
||||
> epilogue_params{};
|
||||
|
||||
epilogue_params.C = tensor_C;
|
||||
epilogue_params.D = tensor_ref_D;
|
||||
epilogue_params.alpha = alpha_host.at(i);
|
||||
epilogue_params.beta = beta_host.at(i);
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
block_D.at(i).sync_host();
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
passed &= cutlass::reference::host::TensorEquals(block_ref_D.at(i).host_view(), block_D.at(i).host_view());
|
||||
}
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
std::cout << " " << options.problem_sizes_host.at(i);
|
||||
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
|
||||
}
|
||||
std::cout << " Groups : " << options.groups << std::endl;
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options<Gemm>(options, host_problem_shapes_available);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl));
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
if (options.verification) {
|
||||
std::cout << " Host-side verification is now running - may be very slow for large cases." << std::endl;
|
||||
result.passed = verify(options);
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
else {
|
||||
std::cout << " Verfication is turned off for this run." << std::endl;
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl));
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average setup and runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host);
|
||||
|
||||
std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS : " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
|
||||
if (__CUDACC_VER_MAJOR__ < 12 ||
|
||||
((__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)
|
||||
)
|
||||
) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 100a).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
std::cout << "Running kernel with 1SM MMA config:" << std::endl;
|
||||
run<Gemm1SM>(options, false /*host_problem_shapes_available*/);
|
||||
std::cout << "Running kernel with 2SM MMA config:" << std::endl;
|
||||
run<Gemm2SM>(options, false /*host_problem_shapes_available*/);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
88
examples/75_blackwell_grouped_gemm/CMakeLists.txt
Normal file
88
examples/75_blackwell_grouped_gemm/CMakeLists.txt
Normal file
@ -0,0 +1,88 @@
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# Note that we set --iterations=0 for all tests below to disable the performance benchmarking.
|
||||
# Only the correctness check will be run by these commands.
|
||||
|
||||
|
||||
|
||||
set(TEST_RANDOM --iterations=0) # Random problem sizes
|
||||
set(TEST_RANDOM_LARGE_GROUP --groups=50 --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes
|
||||
set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=50 --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Random problem sizes
|
||||
set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=1) # Random problem sizes
|
||||
|
||||
set(TEST_FIXED --m=2048 --n=5120 --k=8192 --iterations=0) # Fixed problem sizes
|
||||
set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=51 --iterations=0) # Fixed problem sizes
|
||||
|
||||
set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes
|
||||
set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=50 --iterations=0) # Small problem sizes
|
||||
|
||||
set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes
|
||||
set(TEST_RANDOM_PERF_LARGE_GROUP --groups=50 --iterations=10) # Random problem sizes
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
cutlass_example_add_executable(
|
||||
75_blackwell_grouped_gemm
|
||||
75_blackwell_grouped_gemm.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_RANDOM
|
||||
TEST_RANDOM_LARGE_GROUP
|
||||
TEST_EPILOGUE
|
||||
TEST_EPILOGUE_LARGE_GROUP
|
||||
TEST_EPILOGUE_OP
|
||||
TEST_EPILOGUE_OP_LARGE_GROUP
|
||||
TEST_FIXED
|
||||
TEST_FIXED_LARGE_GROUP
|
||||
TEST_SMALL
|
||||
TEST_SMALL_LARGE_GROUP
|
||||
TEST_RANDOM_PERF
|
||||
TEST_RANDOM_PERF_LARGE_GROUP
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
75_blackwell_grouped_gemm_block_scaled
|
||||
75_blackwell_grouped_gemm_block_scaled.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_RANDOM
|
||||
TEST_RANDOM_LARGE_GROUP
|
||||
TEST_EPILOGUE
|
||||
TEST_EPILOGUE_LARGE_GROUP
|
||||
TEST_EPILOGUE_OP
|
||||
TEST_EPILOGUE_OP_LARGE_GROUP
|
||||
TEST_FIXED
|
||||
TEST_FIXED_LARGE_GROUP
|
||||
TEST_SMALL
|
||||
TEST_SMALL_LARGE_GROUP
|
||||
TEST_RANDOM_PERF
|
||||
TEST_RANDOM_PERF_LARGE_GROUP
|
||||
)
|
||||
endif()
|
||||
534
examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu
Normal file
534
examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu
Normal file
@ -0,0 +1,534 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Simple dgrad convolution example targeting NVIDIA Blackwell SM100 Tensor Core MMA using CUTLASS 3.x APIs.
|
||||
|
||||
This example demonstrate a simple way to instantiate and run a dgrad convolution kernel using the new CUTLASS 3.0
|
||||
APIs on NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
The basic computation logic of dgrad convolution kernel is, take 3D convolution as an example:
|
||||
Xformed Actication (NZPQK) * Weight/Filter (KTRSC) = Activation (NDHWC)
|
||||
|
||||
where in terms of GEMM perspective,
|
||||
Matrix A = Xformed Activation, Matrix B = Weight/Filter, Matrix C = Activation
|
||||
|
||||
This example instantiates a simple dgrad kernel using TMA + UMMA + Warp Specialized design with input and output types are fp16.
|
||||
Alpha/beta scaling is supported while fusions like relu/bias/per-channel scaling are not supported in this example.
|
||||
|
||||
Usage:
|
||||
|
||||
$ ./examples/76_blackwell_conv/76_blackwell_conv_dgrad --n=4 --d=1 --h=8 --w=8 --c=64 --k=64 --t=1 --r=3 --s=3 --pad_d=0
|
||||
--pad_h=1 --pad_w=1 --stride_d=1 --stride_h=1 --stride_w=1 --dilation_d=1 --dilation_h=1 --dilation_w=1
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/convnd_problem_shape.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/conv/dispatch_policy.hpp"
|
||||
#include "cutlass/conv/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/conv/device/conv_universal_adapter.hpp"
|
||||
#include "cutlass/conv/kernel/conv_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/convolution.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Conv kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Activation matrix configuration
|
||||
using ElementAct = half_t; // Element type for activation matrix
|
||||
constexpr int AlignmentAct = 128 / cutlass::sizeof_bits<ElementAct>::value; // Memory access granularity/alignment of activation matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Weight/Filter matrix configuration
|
||||
using ElementFlt = half_t; // Element type for weight/filter matrix operand
|
||||
constexpr int AlignmentFlt = 128 / cutlass::sizeof_bits<ElementFlt>::value; // Memory access granularity/alignment of weight/filter matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Xformed activation matrix configuration
|
||||
using ElementXformedAct = half_t; // Element type for xformed activation matrix operand
|
||||
constexpr int AlignmentXformedAct = 128 / cutlass::sizeof_bits<ElementXformedAct>::value; // Memory access granularity/alignment of xformed activation matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Layout of matrix A/B/C in gemm's perspecitive.
|
||||
using LayoutA = cutlass::layout::TensorNDHWC;
|
||||
using LayoutB = cutlass::layout::TensorNDHWC;
|
||||
using LayoutC = cutlass::layout::TensorNDHWC;
|
||||
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for internal computation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
constexpr cutlass::conv::Operator ConvOp = cutlass::conv::Operator::kDgrad; // Convolution operation
|
||||
|
||||
// Kernel Perf config
|
||||
using TileShape = Shape<_128,_128,Shape<_64>>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
// Build the epilogue
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementAct, LayoutC, AlignmentAct,
|
||||
ElementAct, LayoutC, AlignmentAct,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
// Build the mainloop
|
||||
using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, ConvOp,
|
||||
ElementXformedAct, LayoutA, AlignmentXformedAct,
|
||||
ElementFlt, LayoutB, AlignmentFlt,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::conv::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::conv::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
// Compose into a kernel
|
||||
using ProblemShape=cutlass::conv::ConvProblemShape<ConvOp, CollectiveMainloop::DispatchPolicy::NumSpatialDimensions>;
|
||||
using ConvKernel = cutlass::conv::kernel::ConvUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Conv = cutlass::conv::device::ConvUniversalAdapter<ConvKernel>;
|
||||
|
||||
using StrideC = typename Conv::ConvKernel::StrideC;
|
||||
using StrideD = typename Conv::ConvKernel::StrideD;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::DeviceAllocation<ElementXformedAct> block_A;
|
||||
cutlass::DeviceAllocation<ElementFlt> block_B;
|
||||
cutlass::DeviceAllocation<ElementAct> block_C;
|
||||
cutlass::DeviceAllocation<ElementAct> block_D;
|
||||
cutlass::DeviceAllocation<ElementAct> block_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int n, d, h, w, c, k, t, r, s, z, p, q;
|
||||
int pad_d, pad_h, pad_w;
|
||||
int stride_d, stride_h, stride_w;
|
||||
int dilation_d, dilation_h, dilation_w;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
n(4), d(1), h(8), w(8), c(64), k(64), t(1), r(3), s(3),
|
||||
pad_d(0), pad_h(1), pad_w(1),
|
||||
stride_d(1), stride_h(1), stride_w(1),
|
||||
dilation_d(1), dilation_h(1), dilation_w(1),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("d", d);
|
||||
cmd.get_cmd_line_argument("h", h);
|
||||
cmd.get_cmd_line_argument("w", w);
|
||||
cmd.get_cmd_line_argument("c", c);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("t", t);
|
||||
cmd.get_cmd_line_argument("r", r);
|
||||
cmd.get_cmd_line_argument("s", s);
|
||||
cmd.get_cmd_line_argument("pad_d", pad_d);
|
||||
cmd.get_cmd_line_argument("pad_h", pad_h);
|
||||
cmd.get_cmd_line_argument("pad_w", pad_w);
|
||||
cmd.get_cmd_line_argument("stride_d", stride_d);
|
||||
cmd.get_cmd_line_argument("stride_h", stride_h);
|
||||
cmd.get_cmd_line_argument("stride_w", stride_w);
|
||||
cmd.get_cmd_line_argument("dilation_d", dilation_d);
|
||||
cmd.get_cmd_line_argument("dilation_h", dilation_h);
|
||||
cmd.get_cmd_line_argument("dilation_w", dilation_w);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
|
||||
// Calculate z,p,q based on inputs.
|
||||
z = 1 + (d + 2 * pad_d - ((t - 1) * dilation_d + 1)) / stride_d;
|
||||
p = 1 + (h + 2 * pad_h - ((r - 1) * dilation_h + 1)) / stride_h;
|
||||
q = 1 + (w + 2 * pad_w - ((s - 1) * dilation_w + 1)) / stride_w;
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "76_blackwell_conv_dgrad\n\n"
|
||||
<< " Blackwell FP16 dgrad convolution using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --n=<int> Sets the batch size of the Activation\n"
|
||||
<< " --d=<int> Sets the depth size of the Activation\n"
|
||||
<< " --h=<int> Sets the height of the Activation\n"
|
||||
<< " --w=<int> Sets the width of the Activation\n"
|
||||
<< " --c=<int> Sets the channel size of the Activation\n"
|
||||
<< " --k=<int> Sets the image numbers of the Filter\n"
|
||||
<< " --t=<int> Sets the depth size of the Filter\n"
|
||||
<< " --r=<int> Sets the height of the Filter\n"
|
||||
<< " --s=<int> Sets the width of the Filter\n"
|
||||
<< " --pad_d=<int> Sets the padding size in depth\n"
|
||||
<< " --pad_h=<int> Sets the padding size in height\n"
|
||||
<< " --pad_w=<int> Sets the padding size in width\n"
|
||||
<< " --stride_d=<int> Sets the traversal stride size in depth\n"
|
||||
<< " --stride_h=<int> Sets the traversal stride size in height\n"
|
||||
<< " --stride_w=<int> Sets the traversal stride size in width\n"
|
||||
<< " --dialtion_d=<int> Sets the filter dilation size in depth\n"
|
||||
<< " --dialtion_h=<int> Sets the filter dilation size in height\n"
|
||||
<< " --dialtion_w=<int> Sets the filter dilation size in width\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "76_blackwell_conv_dgrad" << " --n=4 --d=1 --h=8 --w=8 --c=64 --k=64 --t=1 --r=3 --s=3 --pad_d=0"
|
||||
<< " --pad_h=1 --pad_w=1 --stride_d=1 --stride_h=1 --stride_w=1 --dilation_d=1 --dilation_h=1 --dilation_w=1 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * (n * d * h * w) * c * (t * r * s * k);
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Conv setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023) {
|
||||
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(0);
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(-2);
|
||||
} else {
|
||||
scope_max = Element(8);
|
||||
scope_min = Element(-8);
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the Conv and reference Conv
|
||||
void initialize(const Options &options) {
|
||||
|
||||
// Construct ConvProblemShape
|
||||
ProblemShape problem_shape(
|
||||
cutlass::conv::Mode::kCrossCorrelation,
|
||||
{options.n, options.d, options.h, options.w, options.c}, // ndhwc
|
||||
{options.k, options.t, options.r, options.s, options.c}, // ktrsc
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding lower (pad_d, pad_h, pad_w)
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding upper (pad_d, pad_h, pad_w)
|
||||
{options.stride_d, options.stride_h, options.stride_w}, // stride (stride_d, stride_h, stride_w)
|
||||
{options.dilation_d, options.dilation_h, options.dilation_w}, // dilation (dilation_d, dilation_h, dilation_w)
|
||||
1 // group
|
||||
);
|
||||
|
||||
// Setup stride_C/D
|
||||
cute::for_each(cute::make_seq<cute::rank<0>(StrideC{})>{}, [&](auto i) {
|
||||
cute::get<0, i>(stride_C) = problem_shape.stride_C[ProblemShape::RankT-2-i];
|
||||
});
|
||||
cute::for_each(cute::make_seq<cute::rank<0>(StrideD{})>{}, [&](auto i) {
|
||||
cute::get<0, i>(stride_D) = problem_shape.stride_C[ProblemShape::RankT-2-i];
|
||||
});
|
||||
|
||||
block_A.reset(problem_shape.size_A());
|
||||
block_B.reset(problem_shape.size_B());
|
||||
block_C.reset(problem_shape.size_C());
|
||||
block_D.reset(problem_shape.size_C());
|
||||
block_ref_D.reset(problem_shape.size_C());
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Conv::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
// Construct ConvProblemShape
|
||||
ProblemShape problem_shape(
|
||||
cutlass::conv::Mode::kCrossCorrelation,
|
||||
{options.n, options.d, options.h, options.w, options.c}, // ndhwc
|
||||
{options.k, options.t, options.r, options.s, options.c}, // ktrsc
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding lower (pad_d, pad_h, pad_w)
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding upper (pad_d, pad_h, pad_w)
|
||||
{options.stride_d, options.stride_h, options.stride_w}, // stride (stride_d, stride_h, stride_w)
|
||||
{options.dilation_d, options.dilation_h, options.dilation_w}, // dilation (dilation_d, dilation_h, dilation_w)
|
||||
1 // group
|
||||
);
|
||||
|
||||
typename Conv::Arguments arguments{
|
||||
problem_shape,
|
||||
{block_A.get(), block_B.get()},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({options.n, options.z, options.p, options.q, options.k}));
|
||||
cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({options.k, options.t, options.r, options.s, options.c}));
|
||||
cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({options.n, options.d, options.h, options.w, options.c}));
|
||||
cutlass::TensorRef ref_D(block_ref_D.get(), LayoutC::packed({options.n, options.d, options.h, options.w, options.c}));
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Construct Conv3dProblemSize with user defined inputs.
|
||||
cutlass::conv::Conv3dProblemSize problem_size(
|
||||
cutlass::Tensor5DCoord(options.n, options.d, options.h, options.w, options.c), // ndhwc
|
||||
cutlass::Tensor5DCoord(options.k, options.t, options.r, options.s, options.c), // ktrsc
|
||||
cutlass::make_Coord(options.pad_d, options.pad_h, options.pad_w), // padding
|
||||
cutlass::make_Coord(options.stride_d, options.stride_h, options.stride_w), // stride (stride_d, stride_h, stride_w)
|
||||
cutlass::make_Coord(options.dilation_d, options.dilation_h, options.dilation_w), // dilation (dilation_d, dilation_h, dilation_w)
|
||||
cutlass::Tensor5DCoord(options.n, options.z, options.p, options.q, options.k) // nzpqk
|
||||
);
|
||||
|
||||
// Launch device reference conv kernel
|
||||
cutlass::reference::device::Conv3dDgrad(problem_size, ref_A, ref_B, ref_C, ref_D, options.alpha, options.beta);
|
||||
|
||||
// Wait for kernel to finish
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Conv conv;
|
||||
|
||||
// Create a structure of conv kernel arguments suitable for invoking an instance of Conv
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Conv::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(conv.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(conv.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(conv.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(conv.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(conv.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size:" << std::endl;
|
||||
std::cout << " Activation(n,d,h,w,c) = (" << options.n << ',' << options.d << ',' << options.h << ',' << options.w << ',' << options.c << "), ";
|
||||
std::cout << " Filter(k,t,r,s,c) = (" << options.k << ',' << options.t << ',' << options.r << ',' << options.s << ',' << options.c << "), ";
|
||||
std::cout << " Xformed Activation(n,z,p,q,k) = (" << options.n << ',' << options.z << ',' << options.p << ',' << options.q << ',' << options.k << ")" << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run<Conv>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
534
examples/76_blackwell_conv/76_blackwell_conv_fprop.cu
Normal file
534
examples/76_blackwell_conv/76_blackwell_conv_fprop.cu
Normal file
@ -0,0 +1,534 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Simple fprop convolution example targeting NVIDIA Blackwell SM100 Tensor Core MMA using CUTLASS 3.x APIs.
|
||||
|
||||
This example demonstrate a simple way to instantiate and run a fprop convolution kernel using the new CUTLASS 3.0
|
||||
APIs on NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
The basic computation logic of fprop convolution kernel is, take 3D convolution as an example:
|
||||
Activation (NDHWC) * Weight/Filter (KTRSC) = Xformed Actication (NZPQK)
|
||||
|
||||
where in terms of GEMM perspective,
|
||||
Matrix A = Activation, Matrix B = Weight/Filter, Matrix C = Xformed Activation
|
||||
|
||||
This example instantiates a simple fprop kernel using TMA + UMMA + Warp Specialized design with input and output types are fp16.
|
||||
Alpha/beta scaling is supported while fusions like relu/bias/per-channel scaling are not supported in this example.
|
||||
|
||||
Usage:
|
||||
|
||||
$ ./examples/76_blackwell_conv/76_blackwell_conv_fprop --n=4 --d=1 --h=8 --w=8 --c=64 --k=64 --t=1 --r=3 --s=3 --pad_d=0
|
||||
--pad_h=1 --pad_w=1 --stride_d=1 --stride_h=1 --stride_w=1 --dilation_d=1 --dilation_h=1 --dilation_w=1
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/convnd_problem_shape.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/conv/dispatch_policy.hpp"
|
||||
#include "cutlass/conv/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/conv/device/conv_universal_adapter.hpp"
|
||||
#include "cutlass/conv/kernel/conv_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/convolution.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Conv kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Activation matrix configuration
|
||||
using ElementAct = half_t; // Element type for activation matrix
|
||||
constexpr int AlignmentAct = 128 / cutlass::sizeof_bits<ElementAct>::value; // Memory access granularity/alignment of activation matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Weight/Filter matrix configuration
|
||||
using ElementFlt = half_t; // Element type for weight/filter matrix operand
|
||||
constexpr int AlignmentFlt = 128 / cutlass::sizeof_bits<ElementFlt>::value; // Memory access granularity/alignment of weight/filter matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Xformed activation matrix configuration
|
||||
using ElementXformedAct = half_t; // Element type for xformed activation matrix operand
|
||||
constexpr int AlignmentXformedAct = 128 / cutlass::sizeof_bits<ElementXformedAct>::value; // Memory access granularity/alignment of xformed activation matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Layout of matrix A/B/C in gemm's perspecitive.
|
||||
using LayoutA = cutlass::layout::TensorNDHWC;
|
||||
using LayoutB = cutlass::layout::TensorNDHWC;
|
||||
using LayoutC = cutlass::layout::TensorNDHWC;
|
||||
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for internal computation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
constexpr cutlass::conv::Operator ConvOp = cutlass::conv::Operator::kFprop; // Convolution operation
|
||||
|
||||
// Kernel Perf config
|
||||
using TileShape = Shape<_128,_128,Shape<_64>>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
// Build the epilogue
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementXformedAct, LayoutC, AlignmentXformedAct,
|
||||
ElementXformedAct, LayoutC, AlignmentXformedAct,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
// Build the mainloop
|
||||
using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, ConvOp,
|
||||
ElementAct, LayoutA, AlignmentAct,
|
||||
ElementFlt, LayoutB, AlignmentFlt,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::conv::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::conv::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
// Compose into a kernel
|
||||
using ProblemShape=cutlass::conv::ConvProblemShape<ConvOp, CollectiveMainloop::DispatchPolicy::NumSpatialDimensions>;
|
||||
using ConvKernel = cutlass::conv::kernel::ConvUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Conv = cutlass::conv::device::ConvUniversalAdapter<ConvKernel>;
|
||||
|
||||
using StrideC = typename Conv::ConvKernel::StrideC;
|
||||
using StrideD = typename Conv::ConvKernel::StrideD;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::DeviceAllocation<ElementAct> block_A;
|
||||
cutlass::DeviceAllocation<ElementFlt> block_B;
|
||||
cutlass::DeviceAllocation<ElementXformedAct> block_C;
|
||||
cutlass::DeviceAllocation<ElementXformedAct> block_D;
|
||||
cutlass::DeviceAllocation<ElementXformedAct> block_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int n, d, h, w, c, k, t, r, s, z, p, q;
|
||||
int pad_d, pad_h, pad_w;
|
||||
int stride_d, stride_h, stride_w;
|
||||
int dilation_d, dilation_h, dilation_w;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
n(4), d(1), h(8), w(8), c(64), k(64), t(1), r(3), s(3),
|
||||
pad_d(0), pad_h(1), pad_w(1),
|
||||
stride_d(1), stride_h(1), stride_w(1),
|
||||
dilation_d(1), dilation_h(1), dilation_w(1),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("d", d);
|
||||
cmd.get_cmd_line_argument("h", h);
|
||||
cmd.get_cmd_line_argument("w", w);
|
||||
cmd.get_cmd_line_argument("c", c);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("t", t);
|
||||
cmd.get_cmd_line_argument("r", r);
|
||||
cmd.get_cmd_line_argument("s", s);
|
||||
cmd.get_cmd_line_argument("pad_d", pad_d);
|
||||
cmd.get_cmd_line_argument("pad_h", pad_h);
|
||||
cmd.get_cmd_line_argument("pad_w", pad_w);
|
||||
cmd.get_cmd_line_argument("stride_d", stride_d);
|
||||
cmd.get_cmd_line_argument("stride_h", stride_h);
|
||||
cmd.get_cmd_line_argument("stride_w", stride_w);
|
||||
cmd.get_cmd_line_argument("dilation_d", dilation_d);
|
||||
cmd.get_cmd_line_argument("dilation_h", dilation_h);
|
||||
cmd.get_cmd_line_argument("dilation_w", dilation_w);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
|
||||
// Calculate z,p,q based on inputs.
|
||||
z = 1 + (d + 2 * pad_d - ((t - 1) * dilation_d + 1)) / stride_d;
|
||||
p = 1 + (h + 2 * pad_h - ((r - 1) * dilation_h + 1)) / stride_h;
|
||||
q = 1 + (w + 2 * pad_w - ((s - 1) * dilation_w + 1)) / stride_w;
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "76_blackwell_conv_fprop\n\n"
|
||||
<< " Blackwell FP16 fprop convolution using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --n=<int> Sets the batch size of the Activation\n"
|
||||
<< " --d=<int> Sets the depth size of the Activation\n"
|
||||
<< " --h=<int> Sets the height of the Activation\n"
|
||||
<< " --w=<int> Sets the width of the Activation\n"
|
||||
<< " --c=<int> Sets the channel size of the Activation\n"
|
||||
<< " --k=<int> Sets the image numbers of the Filter\n"
|
||||
<< " --t=<int> Sets the depth size of the Filter\n"
|
||||
<< " --r=<int> Sets the height of the Filter\n"
|
||||
<< " --s=<int> Sets the width of the Filter\n"
|
||||
<< " --pad_d=<int> Sets the padding size in depth\n"
|
||||
<< " --pad_h=<int> Sets the padding size in height\n"
|
||||
<< " --pad_w=<int> Sets the padding size in width\n"
|
||||
<< " --stride_d=<int> Sets the traversal stride size in depth\n"
|
||||
<< " --stride_h=<int> Sets the traversal stride size in height\n"
|
||||
<< " --stride_w=<int> Sets the traversal stride size in width\n"
|
||||
<< " --dialtion_d=<int> Sets the filter dilation size in depth\n"
|
||||
<< " --dialtion_h=<int> Sets the filter dilation size in height\n"
|
||||
<< " --dialtion_w=<int> Sets the filter dilation size in width\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "76_blackwell_conv_fprop" << " --n=4 --d=1 --h=8 --w=8 --c=64 --k=64 --t=1 --r=3 --s=3 --pad_d=0"
|
||||
<< " --pad_h=1 --pad_w=1 --stride_d=1 --stride_h=1 --stride_w=1 --dilation_d=1 --dilation_h=1 --dilation_w=1 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * (n * z * p * q) * k * (t * r * s * c);
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Conv setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023) {
|
||||
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(0);
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(-2);
|
||||
} else {
|
||||
scope_max = Element(8);
|
||||
scope_min = Element(-8);
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the Conv and reference Conv
|
||||
void initialize(const Options &options) {
|
||||
|
||||
// Construct ConvProblemShape
|
||||
ProblemShape problem_shape(
|
||||
cutlass::conv::Mode::kCrossCorrelation,
|
||||
{options.n, options.d, options.h, options.w, options.c}, // ndhwc
|
||||
{options.k, options.t, options.r, options.s, options.c}, // ktrsc
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding lower (pad_d, pad_h, pad_w)
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding upper (pad_d, pad_h, pad_w)
|
||||
{options.stride_d, options.stride_h, options.stride_w}, // stride (stride_d, stride_h, stride_w)
|
||||
{options.dilation_d, options.dilation_h, options.dilation_w}, // dilation (dilation_d, dilation_h, dilation_w)
|
||||
1 // group
|
||||
);
|
||||
|
||||
// Setup stride_C/D
|
||||
cute::for_each(cute::make_seq<cute::rank<0>(StrideC{})>{}, [&](auto i) {
|
||||
cute::get<0, i>(stride_C) = problem_shape.stride_C[ProblemShape::RankT-2-i];
|
||||
});
|
||||
cute::for_each(cute::make_seq<cute::rank<0>(StrideD{})>{}, [&](auto i) {
|
||||
cute::get<0, i>(stride_D) = problem_shape.stride_C[ProblemShape::RankT-2-i];
|
||||
});
|
||||
|
||||
block_A.reset(problem_shape.size_A());
|
||||
block_B.reset(problem_shape.size_B());
|
||||
block_C.reset(problem_shape.size_C());
|
||||
block_D.reset(problem_shape.size_C());
|
||||
block_ref_D.reset(problem_shape.size_C());
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Conv::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
// Construct ConvProblemShape
|
||||
ProblemShape problem_shape(
|
||||
cutlass::conv::Mode::kCrossCorrelation,
|
||||
{options.n, options.d, options.h, options.w, options.c}, // ndhwc
|
||||
{options.k, options.t, options.r, options.s, options.c}, // ktrsc
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding lower (pad_d, pad_h, pad_w)
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding upper (pad_d, pad_h, pad_w)
|
||||
{options.stride_d, options.stride_h, options.stride_w}, // stride (stride_d, stride_h, stride_w)
|
||||
{options.dilation_d, options.dilation_h, options.dilation_w}, // dilation (dilation_d, dilation_h, dilation_w)
|
||||
1 // group
|
||||
);
|
||||
|
||||
typename Conv::Arguments arguments{
|
||||
problem_shape,
|
||||
{block_A.get(), block_B.get()},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({options.n, options.d, options.h, options.w, options.c}));
|
||||
cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({options.k, options.t, options.r, options.s, options.c}));
|
||||
cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({options.n, options.z, options.p, options.q, options.k}));
|
||||
cutlass::TensorRef ref_D(block_ref_D.get(), LayoutC::packed({options.n, options.z, options.p, options.q, options.k}));
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Construct Conv3dProblemSize with user defined inputs.
|
||||
cutlass::conv::Conv3dProblemSize problem_size(
|
||||
cutlass::Tensor5DCoord(options.n, options.d, options.h, options.w, options.c), // ndhwc
|
||||
cutlass::Tensor5DCoord(options.k, options.t, options.r, options.s, options.c), // ktrsc
|
||||
cutlass::make_Coord(options.pad_d, options.pad_h, options.pad_w), // padding
|
||||
cutlass::make_Coord(options.stride_d, options.stride_h, options.stride_w), // stride (stride_d, stride_h, stride_w)
|
||||
cutlass::make_Coord(options.dilation_d, options.dilation_h, options.dilation_w), // dilation (dilation_d, dilation_h, dilation_w)
|
||||
cutlass::Tensor5DCoord(options.n, options.z, options.p, options.q, options.k) // nzpqk
|
||||
);
|
||||
|
||||
// Launch device reference conv kernel
|
||||
cutlass::reference::device::Conv3dFprop(problem_size, ref_A, ref_B, ref_C, ref_D, options.alpha, options.beta);
|
||||
|
||||
// Wait for kernel to finish
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Conv conv;
|
||||
|
||||
// Create a structure of conv kernel arguments suitable for invoking an instance of Conv
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Conv::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(conv.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(conv.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(conv.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(conv.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(conv.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size:" << std::endl;
|
||||
std::cout << " Activation(n,d,h,w,c) = (" << options.n << ',' << options.d << ',' << options.h << ',' << options.w << ',' << options.c << "), ";
|
||||
std::cout << " Filter(k,t,r,s,c) = (" << options.k << ',' << options.t << ',' << options.r << ',' << options.s << ',' << options.c << "), ";
|
||||
std::cout << " Xformed Activation(n,z,p,q,k) = (" << options.n << ',' << options.z << ',' << options.p << ',' << options.q << ',' << options.k << ")" << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run<Conv>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
530
examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu
Normal file
530
examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu
Normal file
@ -0,0 +1,530 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Simple wgrad convolution example targeting NVIDIA Blackwell SM100 Tensor Core MMA using CUTLASS 3.x APIs.
|
||||
|
||||
This example demonstrate a simple way to instantiate and run a wgrad convolution kernel using the new CUTLASS 3.0
|
||||
APIs on NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
The basic computation logic of wgrad convolution kernel is, take 3D convolution as an example:
|
||||
Xformed Actication (NZPQK) * Activation (NDHWC) = Weight/Filter (KTRSC)
|
||||
|
||||
where in terms of GEMM perspective,
|
||||
Matrix A = Xformed Activation, Matrix B = Activation, Matrix C = Weight/Filter
|
||||
|
||||
This example instantiates a simple wgrad kernel using TMA + UMMA + Warp Specialized design with input and output types are fp16.
|
||||
Alpha/beta scaling is supported while fusions like relu/bias/per-channel scaling are not supported in this example.
|
||||
|
||||
Usage:
|
||||
|
||||
$ ./examples/76_blackwell_conv/76_blackwell_conv_wgrad --n=4 --d=1 --h=8 --w=8 --c=64 --k=64 --t=1 --r=3 --s=3 --pad_d=0
|
||||
--pad_h=1 --pad_w=1 --stride_d=1 --stride_h=1 --stride_w=1 --dilation_d=1 --dilation_h=1 --dilation_w=1
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/convnd_problem_shape.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/conv/dispatch_policy.hpp"
|
||||
#include "cutlass/conv/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/conv/device/conv_universal_adapter.hpp"
|
||||
#include "cutlass/conv/kernel/conv_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/convolution.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Conv kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Activation matrix configuration
|
||||
using ElementAct = half_t; // Element type for activation matrix
|
||||
constexpr int AlignmentAct = 128 / cutlass::sizeof_bits<ElementAct>::value; // Memory access granularity/alignment of activation matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Weight/Filter matrix configuration
|
||||
using ElementFlt = half_t; // Element type for weight/filter matrix operand
|
||||
constexpr int AlignmentFlt = 128 / cutlass::sizeof_bits<ElementFlt>::value; // Memory access granularity/alignment of weight/filter matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Xformed activation matrix configuration
|
||||
using ElementXformedAct = half_t; // Element type for xformed activation matrix operand
|
||||
constexpr int AlignmentXformedAct = 128 / cutlass::sizeof_bits<ElementXformedAct>::value; // Memory access granularity/alignment of xformed activation matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Layout of matrix A/B/C in gemm's perspecitive.
|
||||
using LayoutA = cutlass::layout::TensorNDHWC;
|
||||
using LayoutB = cutlass::layout::TensorNDHWC;
|
||||
using LayoutC = cutlass::layout::TensorKCSRT;
|
||||
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for internal computation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
constexpr cutlass::conv::Operator ConvOp = cutlass::conv::Operator::kWgrad; // Convolution operation
|
||||
|
||||
// Kernel Perf config
|
||||
using TileShape = Shape<_128,Shape<_128>,Shape<_64>>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
// Build the epilogue
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementFlt, LayoutC, AlignmentFlt,
|
||||
ElementFlt, LayoutC, AlignmentFlt,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
// Build the mainloop
|
||||
using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, ConvOp,
|
||||
ElementXformedAct, LayoutA, AlignmentXformedAct,
|
||||
ElementAct, LayoutB, AlignmentAct,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::conv::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::conv::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
// Compose into a kernel
|
||||
using ProblemShape=cutlass::conv::ConvProblemShape<ConvOp, CollectiveMainloop::DispatchPolicy::NumSpatialDimensions>;
|
||||
using ConvKernel = cutlass::conv::kernel::ConvUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Conv = cutlass::conv::device::ConvUniversalAdapter<ConvKernel>;
|
||||
|
||||
using StrideC = typename Conv::ConvKernel::StrideC;
|
||||
using StrideD = typename Conv::ConvKernel::StrideD;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::DeviceAllocation<ElementXformedAct> block_A;
|
||||
cutlass::DeviceAllocation<ElementAct> block_B;
|
||||
cutlass::DeviceAllocation<ElementFlt> block_C;
|
||||
cutlass::DeviceAllocation<ElementFlt> block_D;
|
||||
cutlass::DeviceAllocation<ElementFlt> block_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int n, d, h, w, c, k, t, r, s, z, p, q;
|
||||
int pad_d, pad_h, pad_w;
|
||||
int stride_d, stride_h, stride_w;
|
||||
int dilation_d, dilation_h, dilation_w;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
n(4), d(1), h(8), w(8), c(64), k(64), t(1), r(3), s(3),
|
||||
pad_d(0), pad_h(1), pad_w(1),
|
||||
stride_d(1), stride_h(1), stride_w(1),
|
||||
dilation_d(1), dilation_h(1), dilation_w(1),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("d", d);
|
||||
cmd.get_cmd_line_argument("h", h);
|
||||
cmd.get_cmd_line_argument("w", w);
|
||||
cmd.get_cmd_line_argument("c", c);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("t", t);
|
||||
cmd.get_cmd_line_argument("r", r);
|
||||
cmd.get_cmd_line_argument("s", s);
|
||||
cmd.get_cmd_line_argument("pad_d", pad_d);
|
||||
cmd.get_cmd_line_argument("pad_h", pad_h);
|
||||
cmd.get_cmd_line_argument("pad_w", pad_w);
|
||||
cmd.get_cmd_line_argument("stride_d", stride_d);
|
||||
cmd.get_cmd_line_argument("stride_h", stride_h);
|
||||
cmd.get_cmd_line_argument("stride_w", stride_w);
|
||||
cmd.get_cmd_line_argument("dilation_d", dilation_d);
|
||||
cmd.get_cmd_line_argument("dilation_h", dilation_h);
|
||||
cmd.get_cmd_line_argument("dilation_w", dilation_w);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
|
||||
// Calculate z,p,q based on inputs.
|
||||
z = 1 + (d + 2 * pad_d - ((t - 1) * dilation_d + 1)) / stride_d;
|
||||
p = 1 + (h + 2 * pad_h - ((r - 1) * dilation_h + 1)) / stride_h;
|
||||
q = 1 + (w + 2 * pad_w - ((s - 1) * dilation_w + 1)) / stride_w;
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "76_blackwell_conv_wgrad\n\n"
|
||||
<< " Blackwell FP16 wgrad convolution using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --n=<int> Sets the batch size of the Activation\n"
|
||||
<< " --d=<int> Sets the depth size of the Activation\n"
|
||||
<< " --h=<int> Sets the height of the Activation\n"
|
||||
<< " --w=<int> Sets the width of the Activation\n"
|
||||
<< " --c=<int> Sets the channel size of the Activation\n"
|
||||
<< " --k=<int> Sets the image numbers of the Filter\n"
|
||||
<< " --t=<int> Sets the depth size of the Filter\n"
|
||||
<< " --r=<int> Sets the height of the Filter\n"
|
||||
<< " --s=<int> Sets the width of the Filter\n"
|
||||
<< " --pad_d=<int> Sets the padding size in depth\n"
|
||||
<< " --pad_h=<int> Sets the padding size in height\n"
|
||||
<< " --pad_w=<int> Sets the padding size in width\n"
|
||||
<< " --stride_d=<int> Sets the traversal stride size in depth\n"
|
||||
<< " --stride_h=<int> Sets the traversal stride size in height\n"
|
||||
<< " --stride_w=<int> Sets the traversal stride size in width\n"
|
||||
<< " --dialtion_d=<int> Sets the filter dilation size in depth\n"
|
||||
<< " --dialtion_h=<int> Sets the filter dilation size in height\n"
|
||||
<< " --dialtion_w=<int> Sets the filter dilation size in width\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "76_blackwell_conv_wgrad" << " --n=4 --d=1 --h=8 --w=8 --c=64 --k=64 --t=1 --r=3 --s=3 --pad_d=0"
|
||||
<< " --pad_h=1 --pad_w=1 --stride_d=1 --stride_h=1 --stride_w=1 --dilation_d=1 --dilation_h=1 --dilation_w=1 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * k * (t * r * s * c) * (n * z * p * q);
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Conv setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023) {
|
||||
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(0);
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(-2);
|
||||
} else {
|
||||
scope_max = Element(8);
|
||||
scope_min = Element(-8);
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the Conv and reference Conv
|
||||
void initialize(const Options &options) {
|
||||
|
||||
// Construct ConvProblemShape
|
||||
ProblemShape problem_shape(
|
||||
cutlass::conv::Mode::kCrossCorrelation,
|
||||
{options.n, options.d, options.h, options.w, options.c}, // ndhwc
|
||||
{options.k, options.t, options.r, options.s, options.c}, // ktrsc
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding lower (pad_d, pad_h, pad_w)
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding upper (pad_d, pad_h, pad_w)
|
||||
{options.stride_d, options.stride_h, options.stride_w}, // stride (stride_d, stride_h, stride_w)
|
||||
{options.dilation_d, options.dilation_h, options.dilation_w}, // dilation (dilation_d, dilation_h, dilation_w)
|
||||
1 // group
|
||||
);
|
||||
|
||||
// Setup stride_C/D
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, problem_shape.shape_C, problem_shape.stride_C, ConvOp);
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, problem_shape.shape_C, problem_shape.stride_C, ConvOp);
|
||||
|
||||
block_A.reset(problem_shape.size_A());
|
||||
block_B.reset(problem_shape.size_B());
|
||||
block_C.reset(problem_shape.size_C());
|
||||
block_D.reset(problem_shape.size_C());
|
||||
block_ref_D.reset(problem_shape.size_C());
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Conv::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
// Construct ConvProblemShape
|
||||
ProblemShape problem_shape(
|
||||
cutlass::conv::Mode::kCrossCorrelation,
|
||||
{options.n, options.d, options.h, options.w, options.c}, // ndhwc
|
||||
{options.k, options.t, options.r, options.s, options.c}, // ktrsc
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding lower (pad_d, pad_h, pad_w)
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding upper (pad_d, pad_h, pad_w)
|
||||
{options.stride_d, options.stride_h, options.stride_w}, // stride (stride_d, stride_h, stride_w)
|
||||
{options.dilation_d, options.dilation_h, options.dilation_w}, // dilation (dilation_d, dilation_h, dilation_w)
|
||||
1 // group
|
||||
);
|
||||
|
||||
typename Conv::Arguments arguments{
|
||||
problem_shape,
|
||||
{block_A.get(), block_B.get()},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({options.n, options.z, options.p, options.q, options.k}));
|
||||
cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({options.n, options.d, options.h, options.w, options.c}));
|
||||
cutlass::TensorRef ref_C(block_C.get(), LayoutA::packed({options.k, options.t, options.r, options.s, options.c}));
|
||||
cutlass::TensorRef ref_D(block_ref_D.get(), LayoutA::packed({options.k, options.t, options.r, options.s, options.c}));
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Construct Conv3dProblemSize with user defined inputs.
|
||||
cutlass::conv::Conv3dProblemSize problem_size(
|
||||
cutlass::Tensor5DCoord(options.n, options.d, options.h, options.w, options.c), // ndhwc
|
||||
cutlass::Tensor5DCoord(options.k, options.t, options.r, options.s, options.c), // ktrsc
|
||||
cutlass::make_Coord(options.pad_d, options.pad_h, options.pad_w), // padding
|
||||
cutlass::make_Coord(options.stride_d, options.stride_h, options.stride_w), // stride (stride_d, stride_h, stride_w)
|
||||
cutlass::make_Coord(options.dilation_d, options.dilation_h, options.dilation_w), // dilation (dilation_d, dilation_h, dilation_w)
|
||||
cutlass::Tensor5DCoord(options.n, options.z, options.p, options.q, options.k) // nzpqk
|
||||
);
|
||||
|
||||
// Launch device reference conv kernel
|
||||
cutlass::reference::device::Conv3dWgrad(problem_size, ref_A, ref_B, ref_C, ref_D, options.alpha, options.beta);
|
||||
|
||||
// Wait for kernel to finish
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Conv conv;
|
||||
|
||||
// Create a structure of conv kernel arguments suitable for invoking an instance of Conv
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Conv::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(conv.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(conv.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(conv.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(conv.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(conv.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size:" << std::endl;
|
||||
std::cout << " Activation(n,d,h,w,c) = (" << options.n << ',' << options.d << ',' << options.h << ',' << options.w << ',' << options.c << "), ";
|
||||
std::cout << " Filter(k,t,r,s,c) = (" << options.k << ',' << options.t << ',' << options.r << ',' << options.s << ',' << options.c << "), ";
|
||||
std::cout << " Xformed Activation(n,z,p,q,k) = (" << options.n << ',' << options.z << ',' << options.p << ',' << options.q << ',' << options.k << ")" << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run<Conv>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
46
examples/76_blackwell_conv/CMakeLists.txt
Normal file
46
examples/76_blackwell_conv/CMakeLists.txt
Normal file
@ -0,0 +1,46 @@
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
|
||||
if (CUTLASS_NVCC_ARCHS MATCHES 100a)
|
||||
cutlass_example_add_executable(
|
||||
76_blackwell_conv_fprop
|
||||
76_blackwell_conv_fprop.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
76_blackwell_conv_dgrad
|
||||
76_blackwell_conv_dgrad.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
76_blackwell_conv_wgrad
|
||||
76_blackwell_conv_wgrad.cu
|
||||
)
|
||||
endif()
|
||||
987
examples/77_blackwell_fmha/77_blackwell_fmha.cu
Normal file
987
examples/77_blackwell_fmha/77_blackwell_fmha.cu
Normal file
@ -0,0 +1,987 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Example implementation of fused multi-head attention for the NVIDIA Blackwell SM100
|
||||
architecture using CUTLASS 3.
|
||||
|
||||
MQA/GQA
|
||||
-------
|
||||
|
||||
The head dimension can be represented as a tuple, where the K/V strides in the
|
||||
first dimension is zero. This has the effect of MQA or GQA.
|
||||
* MHA is (head_size:head_stride).
|
||||
* MQA is (head_size:head_stride) in Q and (head_size:_0) in K and V.
|
||||
* GQA is (grouped_heads,heads_kv):(head_stride,grouped_heads*head_stride) in Q
|
||||
and (grouped_heads,heads_kv):(0,head_stride) in K and V
|
||||
|
||||
Output Scale
|
||||
------------
|
||||
|
||||
The output scale gets passed to the collective mainloop, and is applied
|
||||
using FP32 compute pre-quantization
|
||||
|
||||
Variable Sequence Length
|
||||
------------------------
|
||||
|
||||
For variable sequence length, pass in VariableLength objects
|
||||
(max_seqlen, cumulative_seqlen_ptr) in the problem shape for
|
||||
seqlen Q and KV.
|
||||
|
||||
Support
|
||||
---------
|
||||
|
||||
Right now e4m3 with fp32 compute is using a 256x256 tiling and a head dimension
|
||||
of 128 is supported.
|
||||
|
||||
|
||||
Example usage:
|
||||
$ ./examples/77_blackell_fmha/77_blackell_fmha_fp8 \
|
||||
--b=2048 --h=2048 --d=2048 --q=2048 --k=2048
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <regex>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "reference/fmha_fwd_reference.hpp"
|
||||
#include "reference/reference_abs_error.hpp"
|
||||
|
||||
#include "device/fmha.hpp"
|
||||
#include "collective/fmha_fusion.hpp"
|
||||
#include "collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp"
|
||||
#include "collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp"
|
||||
#include "kernel/fmha_options.hpp"
|
||||
#include "kernel/fmha_tile_scheduler.hpp"
|
||||
#include "kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
using namespace cutlass::fmha::collective;
|
||||
using namespace cutlass::fmha;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
enum class InitStyle {
|
||||
kOne, kLinearStride128, kLinearStride1, kRandom, kNone
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool error = false;
|
||||
|
||||
int b = 1;
|
||||
int h = 1;
|
||||
int h_k = 1;
|
||||
int q = 256;
|
||||
int k = 256;
|
||||
int d = 128;
|
||||
int iterations = 3;
|
||||
bool verify = false;
|
||||
bool verbose = false;
|
||||
|
||||
bool causal = false;
|
||||
bool residual = false;
|
||||
bool varlen = false;
|
||||
int sm_count = 0;
|
||||
|
||||
std::string kernel_filter;
|
||||
|
||||
InitStyle init_style_q = InitStyle::kRandom;
|
||||
InitStyle init_style_k = InitStyle::kRandom;
|
||||
InitStyle init_style_v = InitStyle::kRandom;
|
||||
|
||||
static void get_init_style_argument(cutlass::CommandLine& cmd, const char* name, InitStyle& dst, InitStyle const& src) {
|
||||
std::string s;
|
||||
cmd.get_cmd_line_argument(name, s, s);
|
||||
if (s.empty()) {
|
||||
dst = src;
|
||||
}
|
||||
else {
|
||||
if (s == "r") {
|
||||
dst = InitStyle::kRandom;
|
||||
}
|
||||
else if (s == "1") {
|
||||
dst = InitStyle::kOne;
|
||||
}
|
||||
else if (s == "d") {
|
||||
dst = InitStyle::kLinearStride1;
|
||||
}
|
||||
else if (s == "s") {
|
||||
dst = InitStyle::kLinearStride128;
|
||||
}
|
||||
else if (s == "n") {
|
||||
dst = InitStyle::kNone;
|
||||
}
|
||||
else {
|
||||
std::cout << "Error: " << s << " is not a valid input type.\n";
|
||||
std::exit(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
Options defaults;
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("d", d, defaults.d);
|
||||
cmd.get_cmd_line_argument("h", h, -1);
|
||||
if (h == -1) h = 2048 / d;
|
||||
|
||||
cmd.get_cmd_line_argument("h_k", h_k, -1);
|
||||
if (h_k == -1) h_k = h;
|
||||
|
||||
cmd.get_cmd_line_argument("q", q, -1);
|
||||
cmd.get_cmd_line_argument("k", k, -1);
|
||||
if (q == -1) q = k;
|
||||
if (k == -1) k = q;
|
||||
if (q == -1 && k == -1) q = k = defaults.q;
|
||||
|
||||
cmd.get_cmd_line_argument("b", b, -1);
|
||||
if (b == -1) b = 16384 / k;
|
||||
if (b == 0) b = 1;
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
|
||||
verify = cmd.check_cmd_line_flag("verify");
|
||||
verbose = cmd.check_cmd_line_flag("verbose");
|
||||
varlen = cmd.check_cmd_line_flag("varlen");
|
||||
std::string mask;
|
||||
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
|
||||
if (mask == "no" || mask == "") {
|
||||
causal = residual = false;
|
||||
if (varlen) {
|
||||
residual = true;
|
||||
}
|
||||
}
|
||||
else if (mask == "causal") {
|
||||
residual = false;
|
||||
causal = true;
|
||||
}
|
||||
else if (mask == "residual") {
|
||||
residual = true;
|
||||
causal = false;
|
||||
}
|
||||
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
|
||||
|
||||
get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q);
|
||||
get_init_style_argument(cmd, "init-style", init_style_k, defaults.init_style_q);
|
||||
get_init_style_argument(cmd, "init-style", init_style_v, defaults.init_style_q);
|
||||
get_init_style_argument(cmd, "init-style-q", init_style_q, init_style_q);
|
||||
get_init_style_argument(cmd, "init-style-k", init_style_k, init_style_k);
|
||||
get_init_style_argument(cmd, "init-style-v", init_style_v, init_style_v);
|
||||
|
||||
cmd.get_cmd_line_argument("kernel-filter", kernel_filter, defaults.kernel_filter);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "77_blackwell_fmha\n\n"
|
||||
<< " This example showcases the use of CUTLASS's collective operation builders to easily construct\n"
|
||||
<< " fused multi-head attention forward-passkernels targeting NVIDIA's Blackwell architecture.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --b=<int> Sets the B extent\n"
|
||||
<< " --h=<int> Sets the H extent\n"
|
||||
<< " --h_k=<int> Sets the H_K/V extent (for GQA/MQA)\n"
|
||||
<< " --q=<int> Sets the Q extent\n"
|
||||
<< " --k=<int> Sets the K extent\n"
|
||||
<< " --d=<int> Sets the D extentn"
|
||||
<< " --iterations=<int> Benchmarking iterations\n"
|
||||
<< " --verify Verify results\n"
|
||||
<< " --verbose Print smem and execution time per kernel\n"
|
||||
<< " --mask=<no|residual|causal> Enables masking\n"
|
||||
<< " --varlen Enables variable sequence length\n"
|
||||
<< " B*Q and B*K become the total sequence length\n"
|
||||
<< " and are split B-ways, alternatingly +10% and -10%\n"
|
||||
<< " with the last batch sized to make it fit\n"
|
||||
<< " implies at least residual masking for correctness\n"
|
||||
<< " --sm-count Sets SM count rather than querying it\n"
|
||||
<< " --kernel-filter=<filter> Sets regexp to match kernel against\n"
|
||||
<< "\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
void initialize_block(
|
||||
DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023, InitStyle init_style = InitStyle::kRandom) {
|
||||
|
||||
switch (init_style) {
|
||||
case InitStyle::kOne: {
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, (Element) 1, (Element) 1);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kRandom: {
|
||||
cutlass::reference::device::BlockFillRandomGaussian(
|
||||
block.get(), block.size(), seed, (Element) 0, (Element) 1);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kLinearStride1: {
|
||||
std::vector<Element> data(block.size());
|
||||
for (size_t i = 0; i < block.size() / 128; i ++) {
|
||||
for (int j = 0; j < 128; j++) {
|
||||
data[j + 128*i] = static_cast<Element>((double) (j % 4));
|
||||
}
|
||||
}
|
||||
block.copy_from_host(data.data(), data.size());
|
||||
break;
|
||||
}
|
||||
case InitStyle::kLinearStride128: {
|
||||
std::vector<Element> data(block.size());
|
||||
for (size_t i = 0; i < block.size() / 128; i ++) {
|
||||
for (int j = 0; j < 128; j++) {
|
||||
data[j + 128*i] = static_cast<Element>((double) (i % 4));
|
||||
}
|
||||
}
|
||||
block.copy_from_host(data.data(), data.size());
|
||||
break;
|
||||
}
|
||||
case InitStyle::kNone: {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ExampleResult {
|
||||
bool passed = false;
|
||||
bool verified = false;
|
||||
float runtime_ms = 0;
|
||||
double tflops_tc_s = 0;
|
||||
double tops_exp2_s = 0;
|
||||
double tbytes_s = 0;
|
||||
size_t smem_size = 0;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
bool kIsVarlen,
|
||||
class TileShape,
|
||||
class DispatchPolicy,
|
||||
class ActiveMask,
|
||||
class... KernelOptions
|
||||
>
|
||||
struct FwdRunner {
|
||||
|
||||
#ifdef FP8
|
||||
using Element = cutlass::float_e4m3_t;
|
||||
#else
|
||||
using Element = cutlass::half_t;
|
||||
#endif
|
||||
|
||||
using ElementAccumulatorQK = float;
|
||||
using ElementAccumulatorPV = float;
|
||||
using ElementOut = cutlass::half_t;
|
||||
|
||||
// Q K D (B H)
|
||||
using ProblemShapeRegular = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
|
||||
using ProblemShapeVarlen = cute::tuple<VariableLength, VariableLength, int, cute::tuple<cute::tuple<int, int>, int>>;
|
||||
using ProblemShapeType = std::conditional_t<kIsVarlen, ProblemShapeVarlen, ProblemShapeRegular>;
|
||||
|
||||
using StrideQ = cute::tuple<int, _1, cute::tuple<cute::tuple<int, int>, int>>; // Q D (H_G H_R B)
|
||||
using StrideK = cute::tuple<int, _1, cute::tuple<cute::tuple<_0, int>, int>>; // K D (H_G H_R B)
|
||||
using StrideV = StrideK;
|
||||
using StrideO = StrideQ;
|
||||
using StrideLSE = cute::tuple<_1, cute::tuple<cute::tuple<int, int>, int>>; // Q (H_G H_R B)
|
||||
|
||||
static constexpr bool kIsPersistent = find_option_t<Tag::kIsPersistent, true_type, KernelOptions...>::value;
|
||||
using TileScheduler = std::conditional_t<kIsPersistent, cutlass::fmha::kernel::PersistentTileScheduler, cutlass::fmha::kernel::IndividualTileScheduler>;
|
||||
|
||||
using Mainloop =
|
||||
cutlass::fmha::collective::Sm100FmhaFwdMainloopTmaWarpspecialized<
|
||||
Element, ElementAccumulatorQK, ElementAccumulatorPV,
|
||||
TileShape, StrideQ, StrideK, StrideV,
|
||||
ActiveMask
|
||||
>;
|
||||
using Operation = cutlass::fmha::device::FMHA<
|
||||
cutlass::fmha::kernel::Sm100FmhaFwdKernelTmaWarpspecialized<
|
||||
ProblemShapeType,
|
||||
Mainloop,
|
||||
cutlass::fmha::collective::Sm100FmhaFwdEpilogueTmaWarpspecialized<
|
||||
ElementOut, ElementAccumulatorPV,
|
||||
typename Mainloop::TileShapePV,
|
||||
StrideO, StrideLSE
|
||||
>,
|
||||
TileScheduler
|
||||
>>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideQ stride_Q;
|
||||
StrideK stride_K;
|
||||
StrideV stride_V;
|
||||
StrideO stride_O;
|
||||
StrideLSE stride_LSE;
|
||||
uint64_t seed = 0;
|
||||
|
||||
DeviceAllocation<Element> block_Q;
|
||||
DeviceAllocation<Element> block_K;
|
||||
DeviceAllocation<Element> block_V;
|
||||
DeviceAllocation<ElementOut> block_O;
|
||||
DeviceAllocation<ElementAccumulatorPV> block_LSE;
|
||||
DeviceAllocation<ElementOut> block_ref_O;
|
||||
DeviceAllocation<ElementAccumulatorPV> block_ref_LSE;
|
||||
|
||||
std::vector<int> cumulative_seqlen_q;
|
||||
std::vector<int> cumulative_seqlen_kv;
|
||||
DeviceAllocation<int> device_cumulative_seqlen_q;
|
||||
DeviceAllocation<int> device_cumulative_seqlen_kv;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
bool verify(const ProblemShapeType& problem_shape) {
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
stride_Q);
|
||||
|
||||
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
stride_K);
|
||||
|
||||
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
stride_V);
|
||||
|
||||
Tensor mO = make_tensor(make_gmem_ptr(block_ref_O.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
stride_O);
|
||||
|
||||
Tensor mLSE = make_tensor(make_gmem_ptr(block_ref_LSE.get()),
|
||||
select<0,3>(problem_shape),
|
||||
stride_LSE);
|
||||
|
||||
fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{});
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Reference kernel failed. Last CUDA error: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
const double kMaxDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-2;
|
||||
const double kMeanDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-3;
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
double max_diff = 0;
|
||||
double mean_diff = 0;
|
||||
reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff);
|
||||
|
||||
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if (! passed_O) {
|
||||
std::cerr << "failed O: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
|
||||
// reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff);
|
||||
|
||||
bool passed_LSE = true; // future work
|
||||
// bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
// if ( ! passed_LSE) {
|
||||
// std::cerr << "failed LSE: max diff " << max_diff
|
||||
// << " mean " << mean_diff << std::endl;
|
||||
// }
|
||||
|
||||
return passed_O && passed_LSE;
|
||||
}
|
||||
|
||||
template<class ProblemShape>
|
||||
auto initialize_varlen(const ProblemShape& problem_size, const bool kVarlenSame = true) {
|
||||
int num_batches = get<3,1>(problem_size);
|
||||
|
||||
// generate Q as --b times
|
||||
// gaussian (--Q, --Q / 2) sampled positive
|
||||
// track cumulative
|
||||
std::mt19937 rng(0x202305151552ull);
|
||||
std::normal_distribution<double> dist_q(get<0>(problem_size), get<0>(problem_size) / 2);
|
||||
std::normal_distribution<double> dist_kv(get<1>(problem_size), get<1>(problem_size) / 2);
|
||||
std::cout << "N: " << num_batches << ", Q: " << get<0>(problem_size) << ", KV: " << get<1>(problem_size) << std::endl;
|
||||
|
||||
auto generate_positive_int = [](auto& dist, auto& gen) {
|
||||
int result = 0;
|
||||
do {
|
||||
result = static_cast<int>(dist(gen));
|
||||
} while (result <= 0);
|
||||
return result;
|
||||
};
|
||||
|
||||
cumulative_seqlen_q = {0};
|
||||
cumulative_seqlen_kv = {0};
|
||||
|
||||
int total_seqlen_q = 0;
|
||||
int total_seqlen_kv = 0;
|
||||
int max_seqlen_q = 0;
|
||||
int max_seqlen_kv = 0;
|
||||
|
||||
for (int i = 0; i < num_batches; i++) {
|
||||
int seqlen_q = kVarlenSame ? get<0>(problem_size) : generate_positive_int(dist_q, rng);
|
||||
int seqlen_kv = kVarlenSame ? get<1>(problem_size) : generate_positive_int(dist_kv, rng);
|
||||
|
||||
total_seqlen_q += seqlen_q;
|
||||
total_seqlen_kv += seqlen_kv;
|
||||
|
||||
max_seqlen_q = std::max(max_seqlen_q, seqlen_q);
|
||||
max_seqlen_kv = std::max(max_seqlen_kv, seqlen_kv);
|
||||
|
||||
cumulative_seqlen_q.push_back(cumulative_seqlen_q.back() + seqlen_q);
|
||||
cumulative_seqlen_kv.push_back(cumulative_seqlen_kv.back() + seqlen_kv);
|
||||
}
|
||||
std::cout << "Q max: " << max_seqlen_q << " total: " << total_seqlen_q << " vs even " << num_batches * get<0>(problem_size) << std::endl;
|
||||
std::cout << "KV max: " << max_seqlen_kv << " total: " << total_seqlen_kv << " vs even " << num_batches * get<1>(problem_size) << std::endl;
|
||||
|
||||
ProblemShape problem_size_for_init = problem_size;
|
||||
get<3,1>(problem_size_for_init) = 1;
|
||||
get<0>(problem_size_for_init) = total_seqlen_q;
|
||||
get<1>(problem_size_for_init) = total_seqlen_kv;
|
||||
|
||||
ProblemShapeType problem_size_for_launch;
|
||||
|
||||
get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q};
|
||||
get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv};
|
||||
get<2>(problem_size_for_launch) = get<2>(problem_size);
|
||||
get<3>(problem_size_for_launch) = get<3>(problem_size);
|
||||
|
||||
return cute::make_tuple(problem_size_for_init, problem_size_for_launch);
|
||||
}
|
||||
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
|
||||
ProblemShapeType initialize(const Options& options) {
|
||||
int h_r = options.h / options.h_k;
|
||||
assert(options.h % options.h_k == 0);
|
||||
auto problem_shape_in = cute::make_tuple(options.q, options.k, options.d, cute::make_tuple(cute::make_tuple(h_r, options.h_k), options.b));
|
||||
|
||||
ProblemShapeType problem_shape;
|
||||
decltype(problem_shape_in) problem_size;
|
||||
|
||||
if constexpr (kIsVarlen) {
|
||||
auto [problem_shape_init, problem_shape_launch] = initialize_varlen(problem_shape_in);
|
||||
problem_shape = problem_shape_launch;
|
||||
problem_size = problem_shape_init;
|
||||
}
|
||||
else {
|
||||
problem_size = problem_shape_in;
|
||||
problem_shape = problem_shape_in;
|
||||
}
|
||||
|
||||
get<2>(problem_size) = cutlass::round_up(get<2>(problem_size), 8); // alignment
|
||||
|
||||
auto shape_QO = select<0,2,3>(problem_size);
|
||||
auto shape_KV = select<1,2,3>(problem_size);
|
||||
auto shape_LSE = select<0,3>(problem_size);
|
||||
|
||||
int SQ = size<0>(problem_size);
|
||||
int SK = size<1>(problem_size);
|
||||
int D = size<2>(problem_size);
|
||||
int H = size<3,0>(problem_size);
|
||||
int H_K = size<3,0,1>(problem_size);
|
||||
int H_Q = size<3,0,0>(problem_size);
|
||||
int B = size<3,1>(problem_size);
|
||||
|
||||
stride_Q = make_stride(H*D , _1{}, make_stride(make_stride(D, H_Q*D), H*D*SQ));
|
||||
stride_O = stride_Q;
|
||||
stride_K = make_stride(H_K*D , _1{}, make_stride(make_stride(_0{}, D), H_K*D*SK));
|
||||
stride_V = stride_K;
|
||||
stride_LSE = make_stride(_1{}, make_stride(make_stride(SQ, SQ*H_Q), SQ*H));
|
||||
|
||||
if (kIsVarlen) {
|
||||
get<2,1>(stride_Q) = 0;
|
||||
get<2,1>(stride_K) = 0;
|
||||
get<2,1>(stride_V) = 0;
|
||||
get<2,1>(stride_O) = 0;
|
||||
get<1,1>(stride_LSE) = 0;
|
||||
}
|
||||
|
||||
block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
|
||||
block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
|
||||
block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
|
||||
block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
|
||||
block_LSE.reset(size(shape_LSE));
|
||||
block_ref_O.reset(size(shape_QO));
|
||||
block_ref_LSE.reset(size(shape_LSE));
|
||||
|
||||
initialize_block(block_Q, seed + 2023, options.init_style_q);
|
||||
initialize_block(block_K, seed + 2022, options.init_style_k);
|
||||
initialize_block(block_V, seed + 2021, options.init_style_v);
|
||||
|
||||
if ( ! cumulative_seqlen_q.empty()) {
|
||||
device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
|
||||
device_cumulative_seqlen_q.copy_from_host(
|
||||
cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
|
||||
}
|
||||
if ( ! cumulative_seqlen_kv.empty()) {
|
||||
device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
|
||||
device_cumulative_seqlen_kv.copy_from_host(
|
||||
cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
|
||||
}
|
||||
|
||||
if constexpr (kIsVarlen) {
|
||||
get<0>(problem_shape).cumulative_length = device_cumulative_seqlen_q.get();
|
||||
get<1>(problem_shape).cumulative_length = device_cumulative_seqlen_kv.get();
|
||||
}
|
||||
|
||||
return problem_shape;
|
||||
}
|
||||
|
||||
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
|
||||
|
||||
ProblemShapeType problem_shape = initialize(options);
|
||||
|
||||
typename Operation::Arguments arguments{
|
||||
problem_shape,
|
||||
{ block_Q.get(), stride_Q,
|
||||
block_K.get(), stride_K,
|
||||
block_V.get(), stride_V },
|
||||
{ block_O.get(), stride_O,
|
||||
block_LSE.get(), stride_LSE },
|
||||
hw_info
|
||||
};
|
||||
|
||||
Operation op;
|
||||
|
||||
ExampleResult example_result;
|
||||
|
||||
example_result.smem_size = Operation::Kernel::SharedStorageSize;
|
||||
|
||||
size_t workspace_size = 0;
|
||||
workspace_size = Operation::get_workspace_size(arguments);
|
||||
DeviceAllocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
cutlass::Status status = cutlass::Status::kSuccess;
|
||||
status = op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "This kernel is not supported. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
status = op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Run
|
||||
status = op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
//
|
||||
// Construct events
|
||||
//
|
||||
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (auto & event : events) {
|
||||
result = cudaEventCreate(&event);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
}
|
||||
|
||||
// Record an event at the start of a series of GEMMs
|
||||
result = cudaEventRecord(events[0]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
for (int i = 0; i < options.iterations; i++) {
|
||||
status = op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Stop profiling loop
|
||||
//
|
||||
|
||||
// Record an event when the GEMMs are complete
|
||||
result = cudaEventRecord(events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Wait for work on the device to complete.
|
||||
result = cudaEventSynchronize(events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
result = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
runtime_ms /= static_cast<float>(options.iterations);
|
||||
|
||||
double flops;
|
||||
if (kIsVarlen) {
|
||||
flops = 0.0;
|
||||
for (int i = 0; i < size<3,1>(problem_shape); i++) {
|
||||
flops += (cumulative_seqlen_q[i+1] - cumulative_seqlen_q[i])
|
||||
* 1.0
|
||||
* (cumulative_seqlen_kv[i+1] - cumulative_seqlen_kv[i]);
|
||||
}
|
||||
}
|
||||
else {
|
||||
flops = 1.0;
|
||||
flops *= static_cast<double>(size<0>(problem_shape));
|
||||
flops *= static_cast<double>(size<1>(problem_shape));
|
||||
flops *= static_cast<double>(size<3,1>(problem_shape));
|
||||
}
|
||||
flops *= 4.0 * (std::is_same_v<ActiveMask, CausalMask> ? 0.5 : 1.0);
|
||||
flops *= static_cast<double>(size<2>(problem_shape));
|
||||
flops *= static_cast<double>(size<3,0>(problem_shape));
|
||||
double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
|
||||
example_result.tflops_tc_s = tflops_s;
|
||||
example_result.runtime_ms = runtime_ms;
|
||||
|
||||
result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Verify that the result is correct
|
||||
bool passed = true;
|
||||
if (options.verify) {
|
||||
passed = verify(problem_shape);
|
||||
if (passed) example_result.verified = true;
|
||||
}
|
||||
|
||||
if (!passed) {
|
||||
std::cerr << "Reference check failed" << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
example_result.passed = true;
|
||||
|
||||
return example_result;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to print a description of the example run and its result
|
||||
void print_result(const std::string& description, ExampleResult result, bool verbose) {
|
||||
std::ios fmt(nullptr);
|
||||
fmt.copyfmt(std::cout);
|
||||
std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ");
|
||||
std::cout << std::setw(32) << std::left << description;
|
||||
std::cout.copyfmt(fmt);
|
||||
std::cout << " : " << result.tflops_tc_s << " TFLOPS/s" << std::endl;
|
||||
if (verbose) {
|
||||
std::cout << " t=" << result.runtime_ms << "ms, "
|
||||
"smem=" << result.smem_size << "b" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Mask>
|
||||
void run_fwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||
auto run = [&](auto shape, const char* name, auto... kernel_options) {
|
||||
if ((! options.kernel_filter.empty()) && (! std::regex_search(name, std::basic_regex(options.kernel_filter)))) {
|
||||
return;
|
||||
}
|
||||
if (options.varlen) {
|
||||
FwdRunner<true, decltype(shape), void, Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
}
|
||||
else
|
||||
{
|
||||
FwdRunner<false, decltype(shape), void, Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
}
|
||||
};
|
||||
|
||||
using HeadDim = _128;
|
||||
|
||||
// Persistent Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
|
||||
// Individual Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Mask>
|
||||
void run_fwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||
auto run = [&](auto shape, const char* name, auto... kernel_options) {
|
||||
if ((! options.kernel_filter.empty()) && (! std::regex_search(name, std::basic_regex(options.kernel_filter)))) {
|
||||
return;
|
||||
}
|
||||
if (options.varlen) {
|
||||
FwdRunner<true, decltype(shape), void, Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
}
|
||||
else
|
||||
{
|
||||
FwdRunner<false, decltype(shape), void, Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
}
|
||||
};
|
||||
|
||||
using HeadDim = _64;
|
||||
|
||||
// Persistent Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
|
||||
// Individual Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
|
||||
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Mask>
|
||||
void run_fwd_32(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||
auto run = [&](auto shape, const char* name, auto... kernel_options) {
|
||||
if (options.varlen) {
|
||||
FwdRunner<true, decltype(shape), void, Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
}
|
||||
else {
|
||||
FwdRunner<false, decltype(shape), void, Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
}
|
||||
};
|
||||
|
||||
using HeadDim = _32;
|
||||
|
||||
#ifdef FP8
|
||||
// Persistent Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
|
||||
// Individual Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
|
||||
#endif
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main_single(int argc, char const **args) {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || props.major != 10) {
|
||||
std::cout
|
||||
<< "This example requires a GPU of NVIDIA's Blackwell Architecture "
|
||||
<< "(compute capability major 10) and CUDA 12.8 or greater.\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.error) {
|
||||
std::cerr << "Aborting execution." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
//
|
||||
// Run examples
|
||||
//
|
||||
|
||||
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
|
||||
// information is used by the underlying kernel.
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
if (options.sm_count == 0) {
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
else {
|
||||
hw_info.sm_count = options.sm_count;
|
||||
}
|
||||
|
||||
std::cout << "###### B " << options.b << " H " << options.h << " H_K " << options.h_k << " Q " << options.q << " K " << options.k << " D " << options.d << " ";
|
||||
std::cout << "Forward" << " " << (options.causal ? "Causal" : (options.residual ? "Residual" : "None")) << " ";
|
||||
std::cout << "#SM " << hw_info.sm_count << std::endl;
|
||||
|
||||
auto with_mask = [&](auto fn) {
|
||||
if (options.causal) {
|
||||
fn(CausalMask{});
|
||||
}
|
||||
else if (options.residual) {
|
||||
fn(ResidualMask{});
|
||||
}
|
||||
else {
|
||||
fn(NoMask{});
|
||||
}
|
||||
};
|
||||
|
||||
with_mask([&](auto fusion) {
|
||||
if (options.d <= 32) {
|
||||
run_fwd_32(fusion, options, hw_info);
|
||||
}
|
||||
else if (options.d <= 64) {
|
||||
run_fwd_64(fusion, options, hw_info);
|
||||
}
|
||||
else if (options.d <= 128) {
|
||||
run_fwd_128(fusion, options, hw_info);
|
||||
}
|
||||
else {
|
||||
std::cout << "No kernel instantiated for d=" << options.d << std::endl;
|
||||
}
|
||||
});
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
std::vector<std::string> full_arguments(args, args + argc);
|
||||
|
||||
int result = 0;
|
||||
|
||||
bool recursed = false;
|
||||
for (size_t i = 1; i < full_arguments.size(); i++) {
|
||||
if (full_arguments[i].find(',') != std::string::npos) {
|
||||
auto arg = full_arguments[i];
|
||||
size_t eq_pos = arg.find('=');
|
||||
std::string prefix = eq_pos == std::string::npos ? "" : arg.substr(0, eq_pos+1);
|
||||
std::string rest = eq_pos == std::string::npos ? arg : arg.substr(eq_pos+1);
|
||||
for (;;) {
|
||||
size_t comma_pos = rest.find(',');
|
||||
std::string current = rest.substr(0, comma_pos);
|
||||
full_arguments[i] = prefix + current;
|
||||
std::vector<const char*> next_args;
|
||||
for (auto& elem : full_arguments) { next_args.push_back(elem.data()); }
|
||||
main(argc, next_args.data());
|
||||
if (comma_pos == std::string::npos) break;
|
||||
rest = rest.substr(comma_pos+1);
|
||||
}
|
||||
recursed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (! recursed) {
|
||||
main_single(argc, args);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
865
examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu
Normal file
865
examples/77_blackwell_fmha/77_blackwell_fmha_bwd.cu
Normal file
@ -0,0 +1,865 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Example implementation of fused multi-head attention for Blackwell using CUTLASS 3.
|
||||
|
||||
This example showcases the use of CUTLASS to build backward fused
|
||||
multi-head attantion (FMHA) collectives from existing CUTLASS collectives targeting
|
||||
the NVIDIA Blackwell architecture.
|
||||
|
||||
Background and motivation
|
||||
-------------------------
|
||||
CUTLASS is a highly flexible library that provides open-source building blocks
|
||||
for tensor core programming for GEMM or GEMM-like problems. Fused multi-head
|
||||
attention (FMHA) is a foundational kernel for large language models (LLMs) since it
|
||||
makes long sequence lengths feasible from a memory-usage perspective. It also
|
||||
improves computational efficiency since it transforms an outer-product-like and
|
||||
a matrix-vector-like GEMM into a fused operation with much higher arithmetic
|
||||
intensity. For more details, see Dao et al, 2022; Dao, 2023.
|
||||
Implementing this kernel in CUTLASS enabled easy customization and high
|
||||
performance.
|
||||
|
||||
Introduction
|
||||
------------
|
||||
The example targets the NVIDIA Blackwell architecture, and takes advantage of
|
||||
5th gen tensor cores and the Tensor Memory Accelerator (TMA), just like
|
||||
GEMMs do. It provides a backward pass (often abbreviated
|
||||
bwd in the code).
|
||||
The code is structured into three layers: The runner (and the reference kernels)
|
||||
takes care of initialization, measurement, and testing; the device layer
|
||||
orchestrates kernel calls and partitions workspace; and the kernel layer (just
|
||||
like the CUTLASS kernel layer.
|
||||
|
||||
Support
|
||||
---------
|
||||
|
||||
We support fp16 and fp8 data types with a head dimension of 128.
|
||||
|
||||
Example usage:
|
||||
$ ./examples/77_blackwell_fmha/77_blackwell_fmha_bwd_fp16 \
|
||||
--b=2048 --h=2048 --d=2048 --q=2048 --k=2048
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <regex>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "reference/fmha_fwd_reference.hpp"
|
||||
#include "reference/fmha_bwd_reference.hpp"
|
||||
#include "reference/reference_abs_error.hpp"
|
||||
|
||||
#include "collective/fmha_fusion.hpp"
|
||||
#include "device/fmha_device_bwd.hpp"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
using namespace cutlass::fmha::collective;
|
||||
using namespace cutlass::fmha;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
enum class InitStyle {
|
||||
kOne, kZero, kLinearStride128, kLinearStride1, kRandom, kNone
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool error = false;
|
||||
|
||||
int b = 16;
|
||||
int h = 16;
|
||||
int h_k = 1;
|
||||
int q = 1024;
|
||||
int k = 1024;
|
||||
int d = 128;
|
||||
int iterations = 3;
|
||||
bool verify = false;
|
||||
bool verbose = false;
|
||||
|
||||
bool causal = false;
|
||||
int sm_count = 0;
|
||||
|
||||
std::string kernel_filter;
|
||||
|
||||
InitStyle init_style_q = InitStyle::kRandom;
|
||||
InitStyle init_style_k = InitStyle::kRandom;
|
||||
InitStyle init_style_v = InitStyle::kRandom;
|
||||
InitStyle init_style_do = InitStyle::kRandom;
|
||||
bool skip_reference = false;
|
||||
|
||||
static void get_init_style_argument(cutlass::CommandLine& cmd, const char* name, InitStyle& dst, InitStyle const& src) {
|
||||
std::string s;
|
||||
cmd.get_cmd_line_argument(name, s, s);
|
||||
if (s.empty()) {
|
||||
dst = src;
|
||||
}
|
||||
else {
|
||||
if (s == "r") {
|
||||
dst = InitStyle::kRandom;
|
||||
}
|
||||
else if (s == "0") {
|
||||
dst = InitStyle::kZero;
|
||||
}
|
||||
else if (s == "1") {
|
||||
dst = InitStyle::kOne;
|
||||
}
|
||||
else if (s == "d") {
|
||||
dst = InitStyle::kLinearStride1;
|
||||
}
|
||||
else if (s == "s") {
|
||||
dst = InitStyle::kLinearStride128;
|
||||
}
|
||||
else if (s == "n") {
|
||||
dst = InitStyle::kNone;
|
||||
}
|
||||
else {
|
||||
std::cout << "Error: " << s << " is not a valid input type.\n";
|
||||
std::exit(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
Options defaults;
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("d", d, defaults.d);
|
||||
cmd.get_cmd_line_argument("h", h, -1);
|
||||
if (h == -1) h = 2048 / d;
|
||||
|
||||
cmd.get_cmd_line_argument("q", q, -1);
|
||||
cmd.get_cmd_line_argument("k", k, -1);
|
||||
if (q == -1) q = k;
|
||||
if (k == -1) k = q;
|
||||
if (q == -1 && k == -1) q = k = defaults.q;
|
||||
|
||||
cmd.get_cmd_line_argument("b", b, -1);
|
||||
if (b == -1) b = 16384 / k;
|
||||
if (b == 0) b = 1;
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
|
||||
verify = cmd.check_cmd_line_flag("verify");
|
||||
verbose = cmd.check_cmd_line_flag("verbose");
|
||||
std::string mask;
|
||||
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
|
||||
if (mask == "causal") {
|
||||
causal = true;
|
||||
}
|
||||
else {
|
||||
causal = defaults.causal;
|
||||
}
|
||||
|
||||
skip_reference = cmd.check_cmd_line_flag("skip-reference");
|
||||
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
|
||||
|
||||
get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q);
|
||||
get_init_style_argument(cmd, "init-style", init_style_k, defaults.init_style_k);
|
||||
get_init_style_argument(cmd, "init-style", init_style_v, defaults.init_style_v);
|
||||
get_init_style_argument(cmd, "init-style", init_style_do, defaults.init_style_do);
|
||||
get_init_style_argument(cmd, "init-style-q", init_style_q, init_style_q);
|
||||
get_init_style_argument(cmd, "init-style-k", init_style_k, init_style_k);
|
||||
get_init_style_argument(cmd, "init-style-v", init_style_v, init_style_v);
|
||||
get_init_style_argument(cmd, "init-style-do", init_style_v, init_style_do);
|
||||
|
||||
cmd.get_cmd_line_argument("kernel-filter", kernel_filter, defaults.kernel_filter);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "77_blackwell_fmha_bwd\n\n"
|
||||
<< " This example showcases the use of CUTLASS's collective operation builders to easily construct\n"
|
||||
<< " fused multi-head attention kernels for the backward pass targeting NVIDIA's Blackwell architecture.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --b=<int> Sets the B extent\n"
|
||||
<< " --h=<int> Sets the H extent\n"
|
||||
<< " --q=<int> Sets the Q extent\n"
|
||||
<< " --k=<int> Sets the K extent\n"
|
||||
<< " --d=<int> Sets the D extentn"
|
||||
<< " --iterations=<int> Benchmarking iterations\n"
|
||||
<< " --verify Verify results\n"
|
||||
<< " --verbose Print smem and execution time per kernel\n"
|
||||
<< " --mask=<no|causal> Enables masking\n"
|
||||
<< " --sm-count Sets SM count rather than querying it\n"
|
||||
<< " --kernel-filter=<filter> Sets regexp to match kernel against\n"
|
||||
<< "\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
void initialize_block(
|
||||
DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023, InitStyle init_style = InitStyle::kRandom) {
|
||||
|
||||
switch (init_style) {
|
||||
case InitStyle::kOne: {
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, (Element) 1, (Element) 1);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kZero: {
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, (Element) 0, (Element) 0);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kRandom: {
|
||||
cutlass::reference::device::BlockFillRandomGaussian(
|
||||
block.get(), block.size(), seed, (Element) 0, (Element) 1);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kLinearStride1: {
|
||||
std::vector<Element> data(block.size());
|
||||
for (size_t i = 0; i < block.size() / 128; i ++) {
|
||||
for (int j = 0; j < 128; j++) {
|
||||
data[j + 128*i] = static_cast<Element>((double) (j % 4));
|
||||
}
|
||||
}
|
||||
block.copy_from_host(data.data(), data.size());
|
||||
break;
|
||||
}
|
||||
case InitStyle::kLinearStride128: {
|
||||
std::vector<Element> data(block.size());
|
||||
for (size_t i = 0; i < block.size() / 128; i ++) {
|
||||
for (int j = 0; j < 128; j++) {
|
||||
data[j + 128*i] = static_cast<Element>((double) (i % 4));
|
||||
}
|
||||
}
|
||||
block.copy_from_host(data.data(), data.size());
|
||||
break;
|
||||
}
|
||||
case InitStyle::kNone: {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ExampleResult {
|
||||
bool passed = false;
|
||||
bool verified = false;
|
||||
float runtime_ms = 0;
|
||||
double tflops_tc_s = 0;
|
||||
size_t smem_size = 0;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class TileShape,
|
||||
class DispatchPolicy,
|
||||
class ActiveMask,
|
||||
class... KernelOptions
|
||||
>
|
||||
struct BwdRunner {
|
||||
|
||||
#ifdef FP8
|
||||
using Element = cutlass::float_e4m3_t;
|
||||
#else
|
||||
using Element = cutlass::half_t;
|
||||
#endif
|
||||
using ElementAccumulator = float;
|
||||
|
||||
// Q K D (H B)
|
||||
using ProblemShapeType = cute::tuple<int, int, int, cute::tuple<int, int>>;
|
||||
|
||||
using Operation = cutlass::fmha::device::Sm100FmhaBwd<Element, ElementAccumulator, TileShape, ActiveMask>;
|
||||
|
||||
using TensorStride = Stride<int, _1, Stride<int, int>>; // Seq D (H B)
|
||||
using StrideQ = TensorStride;
|
||||
using StrideK = TensorStride;
|
||||
using StrideV = TensorStride;
|
||||
using StrideO = TensorStride;
|
||||
using StrideLSE = Stride<_1, Stride<int, int>>; // Seq (H B)
|
||||
|
||||
// Backwards specific
|
||||
using StrideDQ = TensorStride;
|
||||
using StrideDK = TensorStride;
|
||||
using StrideDV = TensorStride;
|
||||
using StrideDO = TensorStride;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideQ stride_Q;
|
||||
StrideK stride_K;
|
||||
StrideV stride_V;
|
||||
StrideO stride_O;
|
||||
StrideLSE stride_LSE;
|
||||
|
||||
StrideDQ stride_dQ;
|
||||
StrideDK stride_dK;
|
||||
StrideDV stride_dV;
|
||||
StrideDO stride_dO;
|
||||
|
||||
uint64_t seed = 0;
|
||||
|
||||
DeviceAllocation<Element> block_Q;
|
||||
DeviceAllocation<Element> block_K;
|
||||
DeviceAllocation<Element> block_V;
|
||||
DeviceAllocation<Element> block_O;
|
||||
DeviceAllocation<ElementAccumulator> block_LSE;
|
||||
|
||||
DeviceAllocation<Element> block_dQ;
|
||||
DeviceAllocation<Element> block_dK;
|
||||
DeviceAllocation<Element> block_dV;
|
||||
DeviceAllocation<Element> block_dO;
|
||||
|
||||
DeviceAllocation<Element> block_ref_dQ;
|
||||
DeviceAllocation<Element> block_ref_dK;
|
||||
DeviceAllocation<Element> block_ref_dV;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
bool verify(const ProblemShapeType& problem_shape) {
|
||||
auto [Q, K, D, HB] = problem_shape;
|
||||
auto [H, B] = HB;
|
||||
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
stride_Q);
|
||||
|
||||
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
stride_K);
|
||||
|
||||
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
stride_V);
|
||||
|
||||
Tensor mO = make_tensor(make_gmem_ptr(block_O.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
stride_O);
|
||||
|
||||
// keep going here! (this might be better in cursor)
|
||||
|
||||
Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()),
|
||||
select<0,3>(problem_shape),
|
||||
stride_LSE);
|
||||
|
||||
Tensor mDQ = make_tensor(make_gmem_ptr(block_ref_dQ.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
stride_dQ);
|
||||
|
||||
Tensor mDK = make_tensor(make_gmem_ptr(block_ref_dK.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
stride_dK);
|
||||
|
||||
Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
stride_dV);
|
||||
|
||||
Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
stride_dO);
|
||||
|
||||
fmha_bwd_reference(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, mDK, mDV, ActiveMask{});
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Reference kernel failed. Last CUDA error: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
const double kMaxDiffThresh = sizeof(Element) == 1 ? 1e-0 : 1e-2;
|
||||
const double kMeanDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-3;
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
double max_diff = 0;
|
||||
double mean_diff = 0;
|
||||
reference_abs_diff(block_dQ, block_ref_dQ, max_diff, mean_diff);
|
||||
|
||||
bool passed_dQ = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if (! passed_dQ) {
|
||||
std::cerr << "failed dQ: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
|
||||
reference_abs_diff(block_dK, block_ref_dK, max_diff, mean_diff);
|
||||
|
||||
bool passed_dK = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if (! passed_dK) {
|
||||
std::cerr << "failed dK: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
|
||||
reference_abs_diff(block_dV, block_ref_dV, max_diff, mean_diff);
|
||||
|
||||
bool passed_dV = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if (! passed_dV) {
|
||||
std::cerr << "failed dV: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
|
||||
return passed_dQ && passed_dK && passed_dV;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const ProblemShapeType& problem_shape, Options const& options) {
|
||||
auto [Q, K, D, HB] = problem_shape;
|
||||
auto [H, B] = HB;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
Q = cutlass::round_up(Q, 8); // Alignment
|
||||
|
||||
auto shape_QO = select<0,2,3>(problem_shape);
|
||||
auto shape_KV = select<1,2,3>(problem_shape);
|
||||
auto shape_LSE = select<0,3>(problem_shape);
|
||||
|
||||
stride_Q = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
|
||||
stride_K = make_stride(D, _1{}, make_stride(D*K, D*K*H));
|
||||
stride_V = stride_K;
|
||||
stride_O = stride_Q;
|
||||
stride_LSE = make_stride(_1{}, make_stride(Q, Q*H));
|
||||
|
||||
stride_dQ = stride_Q;
|
||||
stride_dK = stride_K;
|
||||
stride_dV = stride_V;
|
||||
stride_dO = stride_O;
|
||||
|
||||
auto lsize = [](auto shape) {
|
||||
return size(make_shape(1ull, shape));
|
||||
};
|
||||
|
||||
block_Q.reset(lsize(shape_QO));
|
||||
block_K.reset(lsize(shape_KV));
|
||||
block_V.reset(lsize(shape_KV));
|
||||
block_O.reset(lsize(shape_QO));
|
||||
block_LSE.reset(lsize(shape_LSE));
|
||||
|
||||
block_dQ.reset(lsize(shape_QO));
|
||||
block_dK.reset(lsize(shape_KV));
|
||||
block_dV.reset(lsize(shape_KV));
|
||||
block_dO.reset(lsize(shape_QO));
|
||||
|
||||
block_ref_dQ.reset(lsize(shape_QO));
|
||||
block_ref_dK.reset(lsize(shape_KV));
|
||||
block_ref_dV.reset(lsize(shape_KV));
|
||||
|
||||
initialize_block(block_Q, seed + 2023, options.init_style_q);
|
||||
initialize_block(block_K, seed + 2022, options.init_style_k);
|
||||
initialize_block(block_V, seed + 2021, options.init_style_v);
|
||||
initialize_block(block_dO, seed + 2020, options.init_style_do);
|
||||
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
stride_Q);
|
||||
|
||||
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
stride_K);
|
||||
|
||||
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
stride_V);
|
||||
|
||||
Tensor mO = make_tensor(make_gmem_ptr(block_O.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
stride_O);
|
||||
|
||||
Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()),
|
||||
select<0,3>(problem_shape),
|
||||
stride_LSE);
|
||||
|
||||
if (! options.skip_reference) {
|
||||
fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{});
|
||||
}
|
||||
}
|
||||
|
||||
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
|
||||
auto problem_shape = make_shape(options.q, options.k, options.d, make_shape(options.h, options.b));
|
||||
|
||||
initialize(problem_shape, options);
|
||||
|
||||
ElementAccumulator softmax_scale = 1.0f / sqrtf(options.d);
|
||||
|
||||
typename Operation::Arguments arguments{
|
||||
problem_shape,
|
||||
block_Q.get(), stride_Q,
|
||||
block_K.get(), stride_K,
|
||||
block_V.get(), stride_V,
|
||||
block_O.get(), stride_O,
|
||||
block_LSE.get(), stride_LSE,
|
||||
block_dO.get(), stride_dO,
|
||||
block_dQ.get(), stride_dQ,
|
||||
block_dK.get(), stride_dK,
|
||||
block_dV.get(), stride_dV,
|
||||
softmax_scale,
|
||||
hw_info
|
||||
};
|
||||
|
||||
Operation op;
|
||||
|
||||
ExampleResult example_result;
|
||||
|
||||
example_result.smem_size = Operation::Kernel::SharedStorageSize;
|
||||
|
||||
size_t workspace_size = 0;
|
||||
workspace_size = Operation::get_workspace_size(arguments);
|
||||
DeviceAllocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
cutlass::Status status = cutlass::Status::kSuccess;
|
||||
status = op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "This kernel is not supported. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
status = op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Run
|
||||
status = op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
//
|
||||
// Construct events
|
||||
//
|
||||
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (auto & event : events) {
|
||||
result = cudaEventCreate(&event);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
}
|
||||
|
||||
// Record an event at the start of a series of GEMMs
|
||||
result = cudaEventRecord(events[0]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
for (int i = 0; i < options.iterations; i++) {
|
||||
status = op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Stop profiling loop
|
||||
//
|
||||
|
||||
// Record an event when the GEMMs are complete
|
||||
result = cudaEventRecord(events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Wait for work on the device to complete.
|
||||
result = cudaEventSynchronize(events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
result = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
runtime_ms /= static_cast<float>(options.iterations);
|
||||
|
||||
double flops = 10.0 * (std::is_same_v<ActiveMask, CausalMask> ? 0.5 : 1.0);
|
||||
flops *= static_cast<double>(get<0>(problem_shape));
|
||||
flops *= static_cast<double>(get<1>(problem_shape));
|
||||
flops *= static_cast<double>(get<2>(problem_shape));
|
||||
flops *= static_cast<double>(get<3,0>(problem_shape));
|
||||
flops *= static_cast<double>(get<3,1>(problem_shape));
|
||||
double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
|
||||
example_result.tflops_tc_s = tflops_s;
|
||||
example_result.runtime_ms = runtime_ms;
|
||||
|
||||
result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Verify that the result is correct
|
||||
bool passed = true;
|
||||
if (options.verify) {
|
||||
passed = verify(problem_shape);
|
||||
if (passed) example_result.verified = true;
|
||||
}
|
||||
|
||||
if (!passed) {
|
||||
std::cerr << "Reference check failed" << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
example_result.passed = true;
|
||||
|
||||
return example_result;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to print a description of the example run and its result
|
||||
void print_result(const std::string& description, ExampleResult result, bool verbose) {
|
||||
std::ios fmt(nullptr);
|
||||
fmt.copyfmt(std::cout);
|
||||
std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ");
|
||||
std::cout << std::setw(32) << std::left << description;
|
||||
std::cout.copyfmt(fmt);
|
||||
std::cout << " : " << result.tflops_tc_s << " TFLOPS/s" << std::endl;
|
||||
if (verbose) {
|
||||
std::cout << " t=" << result.runtime_ms << "ms, "
|
||||
"smem=" << result.smem_size << "b" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct KernelCoop {};
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Mask>
|
||||
void run_bwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
|
||||
BwdRunner<decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
};
|
||||
|
||||
using HeadDim = _64;
|
||||
|
||||
run(Shape<_128, _128, HeadDim>{}, KernelCoop{}, "tma");
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Mask>
|
||||
void run_bwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||
auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) {
|
||||
BwdRunner<decltype(shape), decltype(kernel), Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
};
|
||||
|
||||
using HeadDim = _128;
|
||||
|
||||
run(Shape<_128, _128, HeadDim>{}, KernelCoop{}, "tma");
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main_single(int argc, char const **args) {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || props.major != 10) {
|
||||
std::cout
|
||||
<< "This example requires a GPU of NVIDIA's Blackwell Architecture "
|
||||
<< "(compute capability 100a) and CUDA 12.8 or greater.\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.error) {
|
||||
std::cerr << "Aborting execution." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
//
|
||||
// Run examples
|
||||
//
|
||||
|
||||
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
|
||||
// information is used by the underlying kernel.
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
if (options.sm_count == 0) {
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
else {
|
||||
hw_info.sm_count = options.sm_count;
|
||||
}
|
||||
|
||||
std::cout << "###### B " << options.b << " H " << options.h << " Q " << options.q << " K " << options.k << " D " << options.d << " ";
|
||||
std::cout << "Backward" << " " << (options.causal ? "Causal" : "Full") << " ";
|
||||
std::cout << "#SM " << hw_info.sm_count << std::endl;
|
||||
|
||||
auto with_causal = [&](auto fn) {
|
||||
if (options.causal) {
|
||||
fn(CausalMask{});
|
||||
}
|
||||
else {
|
||||
fn(NoMask{});
|
||||
}
|
||||
};
|
||||
|
||||
with_causal([&](auto fusion) {
|
||||
if (options.d <= 64) {
|
||||
run_bwd_64(fusion, options, hw_info);
|
||||
}
|
||||
else if (options.d <= 128) {
|
||||
run_bwd_128(fusion, options, hw_info);
|
||||
}
|
||||
else {
|
||||
std::cout << "No kernel instantiated for d=" << options.d << std::endl;
|
||||
}
|
||||
});
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
std::vector<std::string> full_arguments(args, args + argc);
|
||||
|
||||
int result = 0;
|
||||
|
||||
bool recursed = false;
|
||||
for (size_t i = 1; i < full_arguments.size(); i++) {
|
||||
if (full_arguments[i].find(',') != std::string::npos) {
|
||||
auto arg = full_arguments[i];
|
||||
size_t eq_pos = arg.find('=');
|
||||
std::string prefix = eq_pos == std::string::npos ? "" : arg.substr(0, eq_pos+1);
|
||||
std::string rest = eq_pos == std::string::npos ? arg : arg.substr(eq_pos+1);
|
||||
for (;;) {
|
||||
size_t comma_pos = rest.find(',');
|
||||
std::string current = rest.substr(0, comma_pos);
|
||||
full_arguments[i] = prefix + current;
|
||||
std::vector<const char*> next_args;
|
||||
for (auto& elem : full_arguments) { next_args.push_back(elem.data()); }
|
||||
main(argc, next_args.data());
|
||||
if (comma_pos == std::string::npos) break;
|
||||
rest = rest.substr(comma_pos+1);
|
||||
}
|
||||
recursed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (! recursed) {
|
||||
main_single(argc, args);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
831
examples/77_blackwell_fmha/77_blackwell_fmha_gen.cu
Normal file
831
examples/77_blackwell_fmha/77_blackwell_fmha_gen.cu
Normal file
@ -0,0 +1,831 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Example implementation of fused multi-head attention for the NVIDIA Blackwell SM100
|
||||
architecture using CUTLASS 3.
|
||||
|
||||
MQA/GQA
|
||||
-------
|
||||
|
||||
The head dimension can be represented as a tuple, where the K/V strides in the
|
||||
first dimension is zero. This has the effect of MQA or GQA.
|
||||
* MHA is (head_size:head_stride).
|
||||
* MQA is (head_size:head_stride) in Q and (head_size:_0) in K and V.
|
||||
* GQA is (grouped_heads,heads_kv):(head_stride,grouped_heads*head_stride) in Q
|
||||
and (grouped_heads,heads_kv):(0,head_stride) in K and V
|
||||
|
||||
Example usage:
|
||||
$ ./examples/77_blackell_fmha/77_blackell_fmha_gen_fp8 \
|
||||
--b=2048 --h=2048 --d=2048 --k=2048
|
||||
*/
|
||||
|
||||
#define DSHOW(x) print(#x ": "); print(x); print("\n");
|
||||
#define DSHOWT(x) print(#x ": "); print_tensor(x); print("\n");
|
||||
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <regex>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "reference/fmha_fwd_gen_reference.hpp"
|
||||
#include "reference/reference_abs_error.hpp"
|
||||
|
||||
#include "device/fmha.hpp"
|
||||
#include "collective/fmha_fusion.hpp"
|
||||
#include "collective/sm100_fmha_gen_mainloop_warpspecialized.hpp"
|
||||
#include "collective/sm100_fmha_gen_epilogue_warpspecialized.hpp"
|
||||
#include "kernel/sm100_fmha_gen_kernel_warpspecialized.hpp"
|
||||
#include "kernel/fmha_tile_scheduler.hpp"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
enum class InitStyle {
|
||||
kZero, kOne, kLinearStride128, kLinearStride1, kRandom, kNone
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool error = false;
|
||||
|
||||
int b = 1;
|
||||
int h = 1;
|
||||
int h_k = 1;
|
||||
int k = 512;
|
||||
int d = 128;
|
||||
int iterations = 3;
|
||||
bool verify = false;
|
||||
bool verbose = false;
|
||||
bool remap = false;
|
||||
bool varlen = false;
|
||||
bool cache_only = false;
|
||||
|
||||
int sm_count = 0;
|
||||
|
||||
std::string kernel_filter;
|
||||
bool clear_cache = false;
|
||||
|
||||
InitStyle init_style_q = InitStyle::kRandom;
|
||||
InitStyle init_style_cache_k = InitStyle::kRandom;
|
||||
InitStyle init_style_cache_v = InitStyle::kRandom;
|
||||
InitStyle init_style_new_k = InitStyle::kRandom;
|
||||
InitStyle init_style_new_v = InitStyle::kRandom;
|
||||
|
||||
static void get_init_style_argument(cutlass::CommandLine& cmd, const char* name, InitStyle& dst, InitStyle const& src) {
|
||||
std::string s;
|
||||
cmd.get_cmd_line_argument(name, s, s);
|
||||
if (s.empty()) {
|
||||
dst = src;
|
||||
}
|
||||
else {
|
||||
if (s == "r") {
|
||||
dst = InitStyle::kRandom;
|
||||
}
|
||||
else if (s == "0") {
|
||||
dst = InitStyle::kZero;
|
||||
}
|
||||
else if (s == "1") {
|
||||
dst = InitStyle::kOne;
|
||||
}
|
||||
else if (s == "d") {
|
||||
dst = InitStyle::kLinearStride1;
|
||||
}
|
||||
else if (s == "s") {
|
||||
dst = InitStyle::kLinearStride128;
|
||||
}
|
||||
else if (s == "n") {
|
||||
dst = InitStyle::kNone;
|
||||
}
|
||||
else {
|
||||
std::cout << "Error: " << s << " is not a valid input type.\n";
|
||||
std::exit(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
Options defaults;
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("d", d, defaults.d);
|
||||
cmd.get_cmd_line_argument("h", h, -1);
|
||||
if (h == -1) h = 2048 / d;
|
||||
|
||||
cmd.get_cmd_line_argument("h_k", h_k, -1);
|
||||
if (h_k == -1) h_k = h;
|
||||
|
||||
cmd.get_cmd_line_argument("k", k, defaults.k);
|
||||
|
||||
cmd.get_cmd_line_argument("b", b, -1);
|
||||
if (b == -1) b = 16384 / k;
|
||||
if (b == 0) b = 1;
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
|
||||
verify = cmd.check_cmd_line_flag("verify");
|
||||
verbose = cmd.check_cmd_line_flag("verbose");
|
||||
varlen = cmd.check_cmd_line_flag("varlen");
|
||||
remap = cmd.check_cmd_line_flag("remap");
|
||||
cache_only = cmd.check_cmd_line_flag("cache-only");
|
||||
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
|
||||
|
||||
get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q);
|
||||
get_init_style_argument(cmd, "init-style", init_style_cache_k, defaults.init_style_cache_k);
|
||||
get_init_style_argument(cmd, "init-style", init_style_cache_v, defaults.init_style_cache_v);
|
||||
get_init_style_argument(cmd, "init-style", init_style_new_k, defaults.init_style_new_k);
|
||||
get_init_style_argument(cmd, "init-style", init_style_new_v, defaults.init_style_new_v);
|
||||
get_init_style_argument(cmd, "init-style-q", init_style_q, init_style_q);
|
||||
get_init_style_argument(cmd, "init-style-cache-k", init_style_cache_k, init_style_cache_k);
|
||||
get_init_style_argument(cmd, "init-style-cache-v", init_style_cache_v, init_style_cache_v);
|
||||
get_init_style_argument(cmd, "init-style-new-k", init_style_new_k, init_style_new_k);
|
||||
get_init_style_argument(cmd, "init-style-new-v", init_style_new_v, init_style_new_v);
|
||||
|
||||
clear_cache = cmd.check_cmd_line_flag("clear-cache");
|
||||
|
||||
cmd.get_cmd_line_argument("kernel-filter", kernel_filter, defaults.kernel_filter);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "77_blackwell_fmha_gen\n\n"
|
||||
<< " This example showcases the use of CUTLASS's collective operation builders to easily construct\n"
|
||||
<< " fused multi-head attention forward-pass gen-phase kernels targeting NVIDIA's Blackwell architecture.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --b=<int> Sets the B extent\n"
|
||||
<< " --h=<int> Sets the H extent\n"
|
||||
<< " --h_k=<int> Sets the H_K/V extent (for GQA/MQA)\n"
|
||||
<< " --k=<int> Sets the K extent (sampled around this length)\n"
|
||||
<< " --d=<int> Sets the D extentn"
|
||||
<< " --iterations=<int> Benchmarking iterations\n"
|
||||
<< " --verify Verify results\n"
|
||||
<< " --verbose Print smem and execution time per kernel\n"
|
||||
<< " --remap Enables batch index remapping\n"
|
||||
<< " --cache-only Only use data from KV cache, no reading or inserting new entry\n"
|
||||
<< " --varlen Varies sequence length between cache entries\n"
|
||||
<< " --sm-count Sets SM count rather than querying it\n"
|
||||
<< " --clear-cache Clears the cache before benchmarking runs\n"
|
||||
<< " --kernel-filter=<filter> Sets regexp to match kernel against\n"
|
||||
<< "\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
void initialize_block(
|
||||
DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023, InitStyle init_style = InitStyle::kRandom) {
|
||||
|
||||
switch (init_style) {
|
||||
case InitStyle::kZero: {
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, (Element) 0, (Element) 0);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kOne: {
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, (Element) 1, (Element) 1);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kRandom: {
|
||||
cutlass::reference::device::BlockFillRandomGaussian(
|
||||
block.get(), block.size(), seed, (Element) 0, (Element) 1);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kLinearStride1: {
|
||||
std::vector<Element> data(block.size());
|
||||
for (size_t i = 0; i < block.size() / 128; i ++) {
|
||||
for (int j = 0; j < 128; j++) {
|
||||
data[j + 128*i] = static_cast<Element>((double) (j % 4));
|
||||
}
|
||||
}
|
||||
block.copy_from_host(data.data(), data.size());
|
||||
break;
|
||||
}
|
||||
case InitStyle::kLinearStride128: {
|
||||
std::vector<Element> data(block.size());
|
||||
for (size_t i = 0; i < block.size() / 128; i ++) {
|
||||
for (int j = 0; j < 128; j++) {
|
||||
data[j + 128*i] = static_cast<Element>((double) (i % 4));
|
||||
}
|
||||
}
|
||||
block.copy_from_host(data.data(), data.size());
|
||||
break;
|
||||
}
|
||||
case InitStyle::kNone: {
|
||||
break;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ExampleResult {
|
||||
bool supported = false;
|
||||
bool passed = false;
|
||||
bool verified = false;
|
||||
float runtime_ms = 0;
|
||||
double tflops_tc_s = 0;
|
||||
double tops_exp2_s = 0;
|
||||
double tbytes_s = 0;
|
||||
size_t smem_size = 0;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ClearCache {
|
||||
const int size = 1024 * 1024 * 1024 / 4;
|
||||
DeviceAllocation<float> data;
|
||||
bool active = false;
|
||||
|
||||
ClearCache() = default;
|
||||
|
||||
void set_active(bool the_active) {
|
||||
active = the_active;
|
||||
if (active) {
|
||||
data.reset(size);
|
||||
}
|
||||
else {
|
||||
data.reset(0);
|
||||
}
|
||||
}
|
||||
|
||||
void operator ()() {
|
||||
if (active) {
|
||||
initialize_block(data, 0x49314, InitStyle::kRandom);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
enum class KernelType {
|
||||
UMMA_P, UMMA_I
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<KernelType kKernelType, class TileShape, class ThreadShape>
|
||||
struct ExampleRunner {
|
||||
|
||||
using Element = cutlass::float_e5m2_t;
|
||||
using ElementAcc = float;
|
||||
using ElementOut = cutlass::half_t;
|
||||
|
||||
using ProblemShape = Shape<_1, int, int, Shape<Shape<int, int>, int>>;
|
||||
|
||||
using StrideQ = Stride<_0, _1, Stride<Stride<int, int>, int>>;
|
||||
using StrideNewK = Stride<_0, _1, Stride<Stride<_0, int>, int>>;
|
||||
using StrideCacheK = Stride<int, _1, Stride<Stride<_0, int>, int>>;
|
||||
using StrideNewV = StrideNewK;
|
||||
using StrideCacheV = StrideCacheK;
|
||||
using StrideO = StrideQ;
|
||||
|
||||
using Kernel =
|
||||
cutlass::fmha::kernel::Sm100FmhaGenKernelWarpspecialized<
|
||||
ProblemShape,
|
||||
cutlass::fmha::collective::Sm100FmhaGenMainloopWarpspecialized<
|
||||
Element, ElementAcc, ElementAcc, ElementOut,
|
||||
TileShape,
|
||||
StrideQ, StrideNewK, StrideNewV,
|
||||
StrideCacheK, StrideCacheV, StrideO
|
||||
>,
|
||||
cutlass::fmha::collective::Sm100FmhaGenEpilogueWarpspecialized<ElementOut, StrideO>,
|
||||
std::conditional_t<kKernelType == KernelType::UMMA_P,
|
||||
cutlass::fmha::kernel::PersistentTileScheduler,
|
||||
cutlass::fmha::kernel::IndividualTileScheduler
|
||||
>
|
||||
>;
|
||||
|
||||
using Operation = cutlass::fmha::device::FMHA<Kernel>;
|
||||
|
||||
StrideQ stride_q;
|
||||
StrideNewK stride_new_k;
|
||||
StrideNewV stride_new_v;
|
||||
StrideCacheK stride_cache_k;
|
||||
StrideCacheV stride_cache_v;
|
||||
StrideO stride_o;
|
||||
uint64_t seed = 0;
|
||||
|
||||
std::vector<int> seqlen_kv;
|
||||
|
||||
DeviceAllocation<int> block_seqlen_kv;
|
||||
DeviceAllocation<int> block_cache_batch_idx;
|
||||
DeviceAllocation<Element> block_q;
|
||||
DeviceAllocation<Element> block_new_k;
|
||||
DeviceAllocation<Element> block_new_v;
|
||||
DeviceAllocation<Element> block_cache_k;
|
||||
DeviceAllocation<Element> block_cache_v;
|
||||
DeviceAllocation<ElementOut> block_o;
|
||||
|
||||
DeviceAllocation<Element> block_ref_cache_k;
|
||||
DeviceAllocation<Element> block_ref_cache_v;
|
||||
DeviceAllocation<ElementOut> block_ref_o;
|
||||
|
||||
ClearCache clear_cache;
|
||||
|
||||
bool verify(const ProblemShape& problem_shape) {
|
||||
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(block_q.get()), select<0,2,3>(problem_shape), stride_q);
|
||||
Tensor mNewK = make_tensor(make_gmem_ptr(block_new_k.get()), select<0,2,3>(problem_shape), stride_new_k);
|
||||
Tensor mNewV = make_tensor(make_gmem_ptr(block_new_v.get()), select<0,2,3>(problem_shape), stride_new_v);
|
||||
Tensor mCacheK = make_tensor(make_gmem_ptr(block_ref_cache_k.get()), select<1,2,3>(problem_shape), stride_cache_k);
|
||||
Tensor mCacheV = make_tensor(make_gmem_ptr(block_ref_cache_v.get()), select<1,2,3>(problem_shape), stride_cache_v);
|
||||
Tensor mO = make_tensor(make_gmem_ptr(block_ref_o.get()), select<0,2,3>(problem_shape), stride_o);
|
||||
|
||||
fmha_fwd_gen_reference<ElementAcc>(
|
||||
problem_shape, block_seqlen_kv.get(), block_cache_batch_idx.get(),
|
||||
mQ, mNewK, mNewV, mCacheK, mCacheV, mO);
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Reference kernel failed. Last CUDA error: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
const double kMaxDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-2;
|
||||
const double kMeanDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-3;
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
double max_diff = 0;
|
||||
double mean_diff = 0;
|
||||
reference_abs_diff(block_o, block_ref_o, max_diff, mean_diff);
|
||||
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if (! passed_O) {
|
||||
std::cerr << "failed O: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
|
||||
reference_abs_diff(block_cache_k, block_ref_cache_k, max_diff, mean_diff);
|
||||
bool passed_K = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if ( ! passed_K) {
|
||||
std::cerr << "failed Cache K: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
|
||||
reference_abs_diff(block_cache_v, block_ref_cache_v, max_diff, mean_diff);
|
||||
bool passed_V = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if ( ! passed_V) {
|
||||
std::cerr << "failed Cache V: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
|
||||
return passed_O && passed_K && passed_V;
|
||||
}
|
||||
|
||||
ProblemShape initialize(const Options& options) {
|
||||
|
||||
clear_cache.set_active(options.clear_cache);
|
||||
|
||||
std::vector<int> cache_batch_idx;
|
||||
|
||||
// set up stides and sizes
|
||||
if (options.remap) {
|
||||
for (int i = 0; i < options.b; i++) {
|
||||
cache_batch_idx.push_back(i);
|
||||
}
|
||||
std::mt19937 rng(0x202305291305ull);
|
||||
std::shuffle(cache_batch_idx.begin(), cache_batch_idx.end(), rng);
|
||||
}
|
||||
|
||||
seqlen_kv = std::vector<int>(options.b, options.k);
|
||||
if (options.varlen) {
|
||||
std::mt19937 rng(0x202305151552ull);
|
||||
std::normal_distribution<double> dist_kv(options.k, options.k / 2);
|
||||
|
||||
auto generate_positive_int = [](auto& dist, auto& gen) {
|
||||
int result = 0;
|
||||
do {
|
||||
result = static_cast<int>(dist(gen));
|
||||
} while (result <= 0);
|
||||
return result;
|
||||
};
|
||||
|
||||
for (int i = 0; i < options.b; i++) {
|
||||
seqlen_kv[i] = generate_positive_int(dist_kv, rng);
|
||||
}
|
||||
}
|
||||
|
||||
int max_seqlen_kv = 0;
|
||||
for (auto e : seqlen_kv) {
|
||||
max_seqlen_kv = std::max(e, max_seqlen_kv);
|
||||
}
|
||||
|
||||
ProblemShape result = make_shape(_1{}, max_seqlen_kv + 1, options.d, make_shape(make_shape(options.h / options.h_k, options.h_k), options.b));
|
||||
|
||||
stride_q = make_stride(_0{}, _1{}, make_stride(make_stride(options.d, options.d * size<3,0,0>(result)), options.d * size<3,0>(result)));
|
||||
stride_new_k = make_stride(_0{}, _1{}, make_stride(make_stride(_0{}, options.d), options.d * size<3,0,1>(result)));
|
||||
stride_cache_k = make_stride(options.d * size<3,0,1>(result), _1{}, make_stride(make_stride(_0{}, options.d), options.d * size<3,0,1>(result) * get<1>(result)));
|
||||
|
||||
stride_new_v = stride_new_k;
|
||||
stride_cache_v = stride_cache_k;
|
||||
stride_o = stride_q;
|
||||
|
||||
block_q.reset(options.b * get<2,1>(stride_q));
|
||||
if (! options.cache_only) {
|
||||
block_new_k.reset(options.b * get<2,1>(stride_new_k));
|
||||
block_new_v.reset(options.b * get<2,1>(stride_new_v));
|
||||
}
|
||||
block_cache_k.reset(options.b * get<2,1>(stride_cache_k));
|
||||
block_cache_v.reset(options.b * get<2,1>(stride_cache_v));
|
||||
block_o.reset(options.b * get<2,1>(stride_o));
|
||||
|
||||
block_ref_cache_k.reset(options.b * get<2,1>(stride_cache_k));
|
||||
block_ref_cache_v.reset(options.b * get<2,1>(stride_cache_v));
|
||||
block_ref_o.reset(options.b * get<2,1>(stride_o));
|
||||
|
||||
initialize_block(block_q, seed + 2023, options.init_style_q);
|
||||
if (! options.cache_only) {
|
||||
initialize_block(block_new_k, seed + 2022, options.init_style_new_k);
|
||||
initialize_block(block_new_v, seed + 2021, options.init_style_new_v);
|
||||
}
|
||||
|
||||
initialize_block(block_cache_k, seed + 2024 - 2025, options.init_style_cache_k);
|
||||
initialize_block(block_cache_v, seed + 2025, options.init_style_cache_v);
|
||||
|
||||
block_ref_cache_k.copy_from_device(block_cache_k.get(), block_cache_k.size());
|
||||
block_ref_cache_v.copy_from_device(block_cache_v.get(), block_cache_v.size());
|
||||
block_seqlen_kv.reset(seqlen_kv.size());
|
||||
block_seqlen_kv.copy_from_host(seqlen_kv.data(), seqlen_kv.size());
|
||||
|
||||
if (! cache_batch_idx.empty()) {
|
||||
block_cache_batch_idx.reset(cache_batch_idx.size());
|
||||
block_cache_batch_idx.copy_from_host(cache_batch_idx.data(), cache_batch_idx.size());
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
|
||||
auto problem_shape = initialize(options);
|
||||
|
||||
typename Operation::Arguments arguments{
|
||||
problem_shape,
|
||||
block_seqlen_kv.get(), block_cache_batch_idx.get(),
|
||||
block_q.get(), stride_q,
|
||||
block_new_k.get(), stride_new_k,
|
||||
block_new_v.get(), stride_new_v,
|
||||
block_cache_k.get(), stride_cache_k,
|
||||
block_cache_v.get(), stride_cache_v,
|
||||
block_o.get(), stride_o,
|
||||
hw_info
|
||||
};
|
||||
|
||||
Operation op;
|
||||
|
||||
ExampleResult example_result;
|
||||
|
||||
example_result.smem_size = Operation::Kernel::SharedStorageSize;
|
||||
|
||||
size_t workspace_size = 0;
|
||||
workspace_size = Operation::get_workspace_size(arguments);
|
||||
DeviceAllocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
cutlass::Status status = cutlass::Status::kSuccess;
|
||||
status = op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
// std::cerr << "This kernel is not supported. Last CUDA error is: "
|
||||
// << cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
example_result.supported = true;
|
||||
|
||||
status = op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Run
|
||||
status = op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
//
|
||||
// Construct events
|
||||
//
|
||||
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (auto & event : events) {
|
||||
result = cudaEventCreate(&event);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
}
|
||||
|
||||
float total_runtime_ms = 0;
|
||||
|
||||
for (int i = 0; i < options.iterations; i++) {
|
||||
|
||||
clear_cache();
|
||||
|
||||
// Record an event at the start of a series of GEMMs
|
||||
result = cudaEventRecord(events[0]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
status = op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Record an event when the GEMMs are complete
|
||||
result = cudaEventRecord(events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
//
|
||||
// Stop profiling loop
|
||||
//
|
||||
|
||||
// Wait for work on the device to complete.
|
||||
result = cudaEventSynchronize(events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
result = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaDeviceSynchronize() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
total_runtime_ms += runtime_ms;
|
||||
|
||||
}
|
||||
|
||||
float runtime_ms = total_runtime_ms / static_cast<float>(options.iterations);
|
||||
|
||||
double bytes;
|
||||
bytes = 0.0;
|
||||
bytes += double(sizeof(Element) * size<3>(problem_shape)); // Q
|
||||
bytes += double(sizeof(ElementOut) * size<3>(problem_shape)); // O
|
||||
bytes += 2.0 * double(sizeof(Element) * size<3>(problem_shape) / size<3,0,0>(problem_shape)); // NewK, NewV
|
||||
double total_seqlen_kv = 0;
|
||||
for (auto e : seqlen_kv) {
|
||||
total_seqlen_kv += double(e + 1);
|
||||
}
|
||||
bytes += 2.0 * double(sizeof(Element) * size<3,0,1>(problem_shape) * total_seqlen_kv); // CacheK, CacheV
|
||||
bytes *= static_cast<double>(size<2>(problem_shape));
|
||||
double tbytes_s = bytes * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
|
||||
example_result.tbytes_s = tbytes_s;
|
||||
example_result.runtime_ms = runtime_ms;
|
||||
|
||||
result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Verify that the result is correct
|
||||
bool passed = true;
|
||||
if (options.verify) {
|
||||
passed = verify(problem_shape);
|
||||
if (passed) example_result.verified = true;
|
||||
}
|
||||
|
||||
if (!passed) {
|
||||
std::cerr << "Reference check failed" << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
example_result.passed = true;
|
||||
|
||||
return example_result;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to print a description of the example run and its result
|
||||
void print_result(const std::string& description, ExampleResult result, bool verbose) {
|
||||
std::ios fmt(nullptr);
|
||||
fmt.copyfmt(std::cout);
|
||||
std::cout << (result.supported ? (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ") : "[NSUP] ");
|
||||
std::cout << std::setw(32) << std::left << description;
|
||||
std::cout.copyfmt(fmt);
|
||||
std::cout << " : " << result.tbytes_s << " TB/s" << std::endl;
|
||||
if (verbose) {
|
||||
std::cout << " t=" << result.runtime_ms << "ms, "
|
||||
"smem=" << result.smem_size << "b" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main_single(int argc, char const **args) {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || props.major < 10) {
|
||||
std::cout
|
||||
<< "This example requires a GPU of NVIDIA's Blackwell Architecture or "
|
||||
<< "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n";
|
||||
return 0;
|
||||
}
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.error) {
|
||||
std::cerr << "Aborting execution." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
//
|
||||
// Run examples
|
||||
//
|
||||
|
||||
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
|
||||
// information is used by the underlying kernel.
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
if (options.sm_count == 0) {
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
else {
|
||||
hw_info.sm_count = options.sm_count;
|
||||
}
|
||||
|
||||
std::cout << "###### B " << options.b << " H " << options.h << " H_K " << options.h_k << " K " << options.k << " D " << options.d << " ";
|
||||
std::cout << "Gen" << " " << (options.varlen ? "Variable" : "Uniform") << " " << (options.remap ? "Remap" : "Linear") << " ";
|
||||
std::cout << "#SM " << hw_info.sm_count << std::endl;
|
||||
|
||||
using UMMA = true_type;
|
||||
using FFMA2 = false_type;
|
||||
auto run = [&](const char* name, auto kernel_type, auto tile, auto thr) {
|
||||
if ((! options.kernel_filter.empty()) && (! std::regex_search(name, std::basic_regex(options.kernel_filter)))) {
|
||||
return;
|
||||
}
|
||||
ExampleRunner<decltype(kernel_type)::value, decltype(tile), decltype(thr)> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
};
|
||||
|
||||
|
||||
#define RUN(MODE, m, n, k, tm, tn, tk) \
|
||||
run( \
|
||||
#MODE " " #m "x" #n "x" #k " / " #tm "x" #tn "x" #tk, \
|
||||
std::integral_constant<KernelType, KernelType::MODE>{}, Shape<_##m, _##n, _##k>{}, Shape<_##tm, _##tn, _##tk>{} \
|
||||
)
|
||||
|
||||
RUN(UMMA_I, 128, 64, 128, 1, 1, 1);
|
||||
RUN(UMMA_I, 128, 128, 128, 1, 1, 1);
|
||||
RUN(UMMA_I, 128, 256, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 64, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 128, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 256, 128, 1, 1, 1);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
std::vector<std::string> full_arguments(args, args + argc);
|
||||
|
||||
int result = 0;
|
||||
|
||||
bool recursed = false;
|
||||
for (size_t i = 1; i < full_arguments.size(); i++) {
|
||||
if (full_arguments[i].find(',') != std::string::npos) {
|
||||
auto arg = full_arguments[i];
|
||||
size_t eq_pos = arg.find('=');
|
||||
std::string prefix = eq_pos == std::string::npos ? "" : arg.substr(0, eq_pos+1);
|
||||
std::string rest = eq_pos == std::string::npos ? arg : arg.substr(eq_pos+1);
|
||||
for (;;) {
|
||||
size_t comma_pos = rest.find(',');
|
||||
std::string current = rest.substr(0, comma_pos);
|
||||
full_arguments[i] = prefix + current;
|
||||
std::vector<const char*> next_args;
|
||||
for (auto& elem : full_arguments) { next_args.push_back(elem.data()); }
|
||||
main(argc, next_args.data());
|
||||
if (comma_pos == std::string::npos) break;
|
||||
rest = rest.substr(comma_pos+1);
|
||||
}
|
||||
recursed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (! recursed) {
|
||||
main_single(argc, args);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
832
examples/77_blackwell_fmha/77_blackwell_mla.cu
Normal file
832
examples/77_blackwell_fmha/77_blackwell_mla.cu
Normal file
@ -0,0 +1,832 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file A MLA (Multi-Head Latent Attention) inference kernel sample for the
|
||||
NVIDIA Blackwell Architecture.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <regex>
|
||||
#include <cmath>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "reference/fmha_mla_reference.hpp"
|
||||
#include "reference/reference_abs_error.hpp"
|
||||
|
||||
#include "device/sm100_mla.hpp"
|
||||
#include "kernel/sm100_mla_tile_scheduler.hpp"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
enum class InitStyle {
|
||||
kOne, kLinearStride128, kLinearStride1, kRandom, kNone
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool error = false;
|
||||
|
||||
int b = 1;
|
||||
int k = 256;
|
||||
int split_kv = -1; // number of split along k dim.
|
||||
bool is_var_split_kv = false;
|
||||
int max_split_kv = 16;
|
||||
int page = -1;
|
||||
float spread = 0.2f;
|
||||
int iterations = 3;
|
||||
bool verify = false;
|
||||
bool verbose = false;
|
||||
|
||||
int sm_count = 0;
|
||||
|
||||
std::string kernel_filter;
|
||||
|
||||
InitStyle init_style_q = InitStyle::kRandom;
|
||||
InitStyle init_style_c = InitStyle::kRandom;
|
||||
|
||||
static void get_init_style_argument(cutlass::CommandLine& cmd, const char* name, InitStyle& dst, InitStyle const& src) {
|
||||
std::string s;
|
||||
cmd.get_cmd_line_argument(name, s, s);
|
||||
if (s.empty()) {
|
||||
dst = src;
|
||||
}
|
||||
else {
|
||||
if (s == "r") {
|
||||
dst = InitStyle::kRandom;
|
||||
}
|
||||
else if (s == "1") {
|
||||
dst = InitStyle::kOne;
|
||||
}
|
||||
else if (s == "d") {
|
||||
dst = InitStyle::kLinearStride1;
|
||||
}
|
||||
else if (s == "s") {
|
||||
dst = InitStyle::kLinearStride128;
|
||||
}
|
||||
else if (s == "n") {
|
||||
dst = InitStyle::kNone;
|
||||
}
|
||||
else {
|
||||
std::cout << "Error: " << s << " is not a valid input type.\n";
|
||||
std::exit(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
Options defaults;
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("k", k, -1);
|
||||
if (k == -1) k = defaults.k;
|
||||
|
||||
cmd.get_cmd_line_argument("b", b, -1);
|
||||
if (b == -1) b = 16384 / k;
|
||||
if (b == 0) b = 1;
|
||||
|
||||
cmd.get_cmd_line_argument("split_kv", split_kv, defaults.split_kv);
|
||||
cmd.get_cmd_line_argument("page", page, defaults.page);
|
||||
cmd.get_cmd_line_argument("spread", spread, defaults.spread);
|
||||
cmd.get_cmd_line_argument("is_var_split_kv", is_var_split_kv, false);
|
||||
if (page == -1) {
|
||||
is_var_split_kv = false;
|
||||
}
|
||||
cmd.get_cmd_line_argument("max_split_kv", max_split_kv, defaults.max_split_kv);
|
||||
if (is_var_split_kv == true) {
|
||||
split_kv = max_split_kv;
|
||||
}
|
||||
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
|
||||
verify = cmd.check_cmd_line_flag("verify");
|
||||
verbose = cmd.check_cmd_line_flag("verbose");
|
||||
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
|
||||
|
||||
get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q);
|
||||
get_init_style_argument(cmd, "init-style", init_style_c, defaults.init_style_c);
|
||||
get_init_style_argument(cmd, "init-style-q", init_style_q, init_style_q);
|
||||
get_init_style_argument(cmd, "init-style-c", init_style_c, init_style_c);
|
||||
|
||||
cmd.get_cmd_line_argument("kernel-filter", kernel_filter, defaults.kernel_filter);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "77_blackwell_mla\n\n"
|
||||
<< " This example showcases the use of CUTLASS for fused multi-head latent\n"
|
||||
<< " attention kernels targeting NVIDIA's Blackwell architecture.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --b=<int> Sets the B extent\n"
|
||||
<< " --k=<int> Sets the K extent\n"
|
||||
<< " --page=<int> Enables paging and sets the page size\n"
|
||||
<< " --iterations=<int> Benchmarking iterations\n"
|
||||
<< " --spread=<float> Relative spread away from K for paging\n"
|
||||
<< " --split_kv=<int> Split KV factor\n"
|
||||
<< " --verify Verify results\n"
|
||||
<< " --verbose Print smem and execution time per kernel\n"
|
||||
<< " --sm-count Sets SM count rather than querying it\n"
|
||||
<< " --kernel-filter=<filter> Sets regexp to match kernel against\n"
|
||||
<< "\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
void initialize_block(
|
||||
DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023, InitStyle init_style = InitStyle::kRandom) {
|
||||
|
||||
switch (init_style) {
|
||||
case InitStyle::kOne: {
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, (Element) 1, (Element) 1);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kRandom: {
|
||||
cutlass::reference::device::BlockFillRandomGaussian(
|
||||
block.get(), block.size(), seed, (Element) -1, (Element) 1);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kLinearStride1: {
|
||||
std::vector<Element> data(block.size());
|
||||
for (size_t i = 0; i < block.size() / 128; i ++) {
|
||||
for (int j = 0; j < 128; j++) {
|
||||
data[j + 128*i] = static_cast<Element>((double) (j % 4));
|
||||
}
|
||||
}
|
||||
block.copy_from_host(data.data(), data.size());
|
||||
break;
|
||||
}
|
||||
case InitStyle::kLinearStride128: {
|
||||
std::vector<Element> data(block.size());
|
||||
for (size_t i = 0; i < block.size() / 64; i ++) {
|
||||
for (int j = 0; j < 64; j++) {
|
||||
data[j + 64*i] = static_cast<Element>((double) (i % 9));
|
||||
}
|
||||
}
|
||||
block.copy_from_host(data.data(), data.size());
|
||||
break;
|
||||
}
|
||||
case InitStyle::kNone: {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ExampleResult {
|
||||
bool passed = false;
|
||||
bool verified = false;
|
||||
float runtime_ms = 0;
|
||||
double tflops_tc_s = 0;
|
||||
double tbytes_s = 0;
|
||||
size_t smem_size = 0;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<bool v>
|
||||
struct IsPersistent {
|
||||
static const bool value = v;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class TileShape,
|
||||
class PersistenceOption = IsPersistent<true>
|
||||
>
|
||||
struct Runner {
|
||||
|
||||
#ifdef FP8
|
||||
using Element = cutlass::float_e4m3_t;
|
||||
#elif FP16
|
||||
using Element = cutlass::half_t;
|
||||
#else
|
||||
#error "Must either define FP8 or FP16"
|
||||
#endif
|
||||
|
||||
using ElementAcc = float;
|
||||
using ElementOut = cutlass::half_t;
|
||||
|
||||
using TileShapeH = cute::tuple_element_t<0, TileShape>;
|
||||
using TileShapeD = cute::tuple_element_t<2, TileShape>;
|
||||
|
||||
// H K (D_latent D_rope) B
|
||||
using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
|
||||
|
||||
using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B
|
||||
using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B
|
||||
using StrideO = StrideK; // H D B
|
||||
using StrideLSE = cute::tuple<_1, int>; // H B
|
||||
|
||||
using TileScheduler = std::conditional_t<
|
||||
PersistenceOption::value,
|
||||
Sm100MlaPersistentTileScheduler,
|
||||
Sm100MlaIndividualTileScheduler
|
||||
>;
|
||||
|
||||
using Kernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
|
||||
TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler
|
||||
>;
|
||||
using Operation = cutlass::fmha::device::MLA<Kernel>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideQ stride_Q_latent;
|
||||
StrideK stride_C_latent;
|
||||
StrideQ stride_Q_rope;
|
||||
StrideK stride_K_rope;
|
||||
StrideO stride_O;
|
||||
StrideLSE stride_LSE;
|
||||
StrideLSE stride_PT;
|
||||
|
||||
uint64_t seed = 0;
|
||||
|
||||
int page_size = -1;
|
||||
int page_count = -1;
|
||||
|
||||
// We allocate Q and C as first latent, then rope
|
||||
// This means that we offset the pointer by HeadDim_latent to get the rope
|
||||
// portion
|
||||
DeviceAllocation<Element> block_Q;
|
||||
DeviceAllocation<Element> block_C;
|
||||
DeviceAllocation<ElementOut> block_O;
|
||||
DeviceAllocation<int> block_seq;
|
||||
DeviceAllocation<int> block_PT;
|
||||
DeviceAllocation<int> block_split_kv;
|
||||
DeviceAllocation<int> block_accum_split_len;
|
||||
DeviceAllocation<ElementAcc> block_LSE;
|
||||
DeviceAllocation<ElementOut> block_ref_O;
|
||||
DeviceAllocation<ElementAcc> block_ref_LSE;
|
||||
|
||||
ElementAcc scale;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
bool verify(const ProblemShape& problem_shape) {
|
||||
auto [H, K, D, B] = problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
int page_K = K;
|
||||
int page_B = B;
|
||||
if (block_PT.get() != nullptr) {
|
||||
page_K = page_size;
|
||||
page_B = page_count;
|
||||
}
|
||||
|
||||
Tensor mQ_latent = make_tensor(make_gmem_ptr(block_Q.get()),
|
||||
cute::make_tuple(H, D_latent, B),
|
||||
stride_Q_latent);
|
||||
|
||||
Tensor mQ_rope = make_tensor(make_gmem_ptr(block_Q.get() + D_latent),
|
||||
cute::make_tuple(H, D_rope, B),
|
||||
stride_Q_rope);
|
||||
|
||||
Tensor mC_latent = make_tensor(make_gmem_ptr(block_C.get()),
|
||||
cute::make_tuple(page_K, D_latent, page_B),
|
||||
stride_C_latent);
|
||||
|
||||
Tensor mK_rope = make_tensor(make_gmem_ptr(block_C.get() + D_latent),
|
||||
cute::make_tuple(page_K, D_rope, page_B),
|
||||
stride_K_rope);
|
||||
|
||||
Tensor mO = make_tensor(make_gmem_ptr(block_ref_O.get()),
|
||||
cute::make_tuple(H, D_latent, B),
|
||||
stride_O);
|
||||
|
||||
Tensor mLSE = make_tensor(make_gmem_ptr(block_ref_LSE.get()),
|
||||
cute::make_tuple(H, B),
|
||||
stride_LSE);
|
||||
|
||||
Tensor mSeq = make_tensor(make_gmem_ptr(static_cast<int*>(block_seq.get())), make_shape(B));
|
||||
Tensor mPT = make_tensor(make_gmem_ptr(static_cast<int*>(block_PT.get())), make_shape(ceil_div(K, page_size), B), stride_PT);
|
||||
|
||||
fmha_mla_reference(problem_shape, mSeq, mPT, mQ_latent, mQ_rope, mC_latent, mK_rope, mO, mLSE, scale);
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Reference kernel failed. Last CUDA error: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
const double kMaxDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-2;
|
||||
const double kMeanDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-3;
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
double max_diff = 0;
|
||||
double mean_diff = 0;
|
||||
#ifdef B2B
|
||||
reference_rel_diff(block_O, block_ref_O, max_diff, mean_diff);
|
||||
#else
|
||||
reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff);
|
||||
#endif
|
||||
|
||||
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if (! passed_O) {
|
||||
std::cerr << "failed O: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
|
||||
bool passed_LSE = true;
|
||||
#ifndef B2B
|
||||
reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff);
|
||||
|
||||
passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if ( ! passed_LSE) {
|
||||
std::cerr << "failed LSE: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
return passed_O && passed_LSE;
|
||||
}
|
||||
|
||||
ProblemShape initialize(const Options& options) {
|
||||
auto problem_shape = cute::make_tuple(TileShapeH{}, options.k, TileShapeD{}, options.b);
|
||||
|
||||
auto [H, K, D, B] = problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
// the scale is based on the non-absorbed sizes, change as appropriate
|
||||
// we can't determine this parameter from the info we have, it's an input
|
||||
int D_non_latent = 128;
|
||||
scale = static_cast<decltype(scale)>(1.0 / sqrt(1.0 * (D_non_latent + D_rope)));
|
||||
// Shape (H, D, B)
|
||||
stride_Q_latent = cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(H * (0 + D_latent + D_rope)));
|
||||
stride_Q_rope = stride_Q_latent;
|
||||
stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{}, static_cast<int64_t>(0 + H * D_latent));
|
||||
stride_LSE = cute::make_tuple(_1{}, 0 + H);
|
||||
|
||||
block_Q.reset(static_cast<size_t>(options.b) * H * (D_latent + D_rope));
|
||||
block_O.reset(static_cast<size_t>(options.b) * H * D_latent);
|
||||
block_LSE.reset(static_cast<size_t>(options.b) * H);
|
||||
block_ref_O.reset(static_cast<size_t>(options.b) * H * D_latent);
|
||||
block_ref_LSE.reset(static_cast<size_t>(options.b) * H);
|
||||
|
||||
if (options.page == -1) {
|
||||
|
||||
stride_C_latent = cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(options.k) * (D_latent + D_rope));
|
||||
stride_K_rope = stride_C_latent;
|
||||
|
||||
block_C.reset(static_cast<size_t>(options.b) * options.k * (D_latent + D_rope));
|
||||
|
||||
}
|
||||
else {
|
||||
|
||||
float spread = options.spread;
|
||||
int max_K = static_cast<int>((1 + spread) * K);
|
||||
int min_K = static_cast<int>((1 - spread) * K);
|
||||
page_size = options.page;
|
||||
page_count = B * ceil_div(max_K, page_size);
|
||||
stride_PT = cute::make_stride(_1{}, page_count);
|
||||
|
||||
std::vector<int> host_seq(B);
|
||||
std::vector<int> host_PT(page_count * B);
|
||||
|
||||
for (int i = 0; i < B; i++) {
|
||||
int seq = min_K + rand() % (max_K - min_K + 1);
|
||||
host_seq[i] = seq;
|
||||
for (int j = 0; j < ceil_div(seq, page_size); j++) {
|
||||
host_PT[page_count * i + j] = i + j * B;
|
||||
}
|
||||
}
|
||||
|
||||
block_seq.reset(host_seq.size());
|
||||
block_seq.copy_from_host(host_seq.data(), host_seq.size());
|
||||
block_PT.reset(host_PT.size());
|
||||
block_PT.copy_from_host(host_PT.data(), host_PT.size());
|
||||
|
||||
get<1>(problem_shape) = max_K;
|
||||
|
||||
stride_C_latent = cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{}, page_size * static_cast<int64_t>((D_latent + D_rope)));
|
||||
stride_K_rope = stride_C_latent;
|
||||
|
||||
block_C.reset(page_count * page_size * static_cast<int64_t>((D_latent + D_rope)));
|
||||
|
||||
if (options.is_var_split_kv == true) {
|
||||
std::vector<int> host_split_kv(B);
|
||||
for(int i = 0; i < B; ++i) {
|
||||
auto len = host_seq[i];
|
||||
int split = ceil_div(options.max_split_kv, ceil_div(max_K, len));
|
||||
host_split_kv[i] = split;
|
||||
}
|
||||
block_split_kv.reset(B);
|
||||
block_split_kv.copy_from_host(host_split_kv.data(), host_split_kv.size());
|
||||
}
|
||||
}
|
||||
|
||||
initialize_block(block_Q, seed + 2023, options.init_style_q);
|
||||
initialize_block(block_C, seed + 2022, options.init_style_c);
|
||||
|
||||
return problem_shape;
|
||||
}
|
||||
|
||||
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
|
||||
|
||||
ProblemShape problem_shape = initialize(options);
|
||||
|
||||
auto [H, K, D, B] = problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
typename Operation::Arguments arguments{
|
||||
problem_shape,
|
||||
{ scale,
|
||||
block_Q.get(), stride_Q_latent,
|
||||
block_Q.get() + D_latent, stride_Q_rope,
|
||||
block_C.get(), stride_C_latent,
|
||||
block_C.get() + D_latent, stride_K_rope,
|
||||
block_seq.get(),
|
||||
block_PT.get(), stride_PT,
|
||||
page_count, page_size},
|
||||
{ block_O.get(),
|
||||
stride_O,
|
||||
block_LSE.get(),
|
||||
stride_LSE},
|
||||
hw_info,
|
||||
options.split_kv,
|
||||
options.is_var_split_kv ? block_split_kv.get() : nullptr
|
||||
};
|
||||
if (options.split_kv < 0 && !options.is_var_split_kv) {
|
||||
Operation::set_split_kv(arguments);
|
||||
}
|
||||
|
||||
Operation op;
|
||||
|
||||
ExampleResult example_result;
|
||||
|
||||
example_result.smem_size = Operation::Kernel::SharedStorageSize;
|
||||
|
||||
size_t workspace_size = 0;
|
||||
workspace_size = Operation::get_workspace_size(arguments);
|
||||
DeviceAllocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
cutlass::Status status = cutlass::Status::kSuccess;
|
||||
status = op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "This kernel is not supported. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
status = op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
// Run
|
||||
status = op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
//
|
||||
// Construct events
|
||||
//
|
||||
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (auto & event : events) {
|
||||
result = cudaEventCreate(&event);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
}
|
||||
|
||||
// Record an event at the start of a series of GEMMs
|
||||
result = cudaEventRecord(events[0]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
for (int i = 0; i < options.iterations; i++) {
|
||||
status = op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Stop profiling loop
|
||||
//
|
||||
|
||||
// Record an event when the GEMMs are complete
|
||||
result = cudaEventRecord(events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Wait for work on the device to complete.
|
||||
result = cudaEventSynchronize(events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
result = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
runtime_ms /= static_cast<float>(options.iterations);
|
||||
|
||||
double flops = 1.0;
|
||||
flops *= B;
|
||||
flops *= K;
|
||||
flops *= H;
|
||||
flops *= 2.0;
|
||||
flops *= (2.0 * D_latent + D_rope);
|
||||
|
||||
double bytes_q = sizeof(Element);
|
||||
bytes_q *= B;
|
||||
bytes_q *= H;
|
||||
bytes_q *= (D_latent + D_rope);
|
||||
double bytes_c = sizeof(Element);
|
||||
bytes_c *= B;
|
||||
bytes_c *= options.k; // K may be max_K here
|
||||
bytes_c *= (D_latent + D_rope);
|
||||
double bytes_o = sizeof(ElementOut);
|
||||
bytes_o *= B;
|
||||
bytes_o *= H;
|
||||
bytes_o *= D_latent;
|
||||
double bytes = bytes_q + bytes_c + bytes_o;
|
||||
|
||||
double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
|
||||
double tbytes_s = bytes * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
|
||||
example_result.tflops_tc_s = tflops_s;
|
||||
example_result.tbytes_s = tbytes_s;
|
||||
example_result.runtime_ms = runtime_ms;
|
||||
|
||||
result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Verify that the result is correct
|
||||
bool passed = true;
|
||||
if (options.verify) {
|
||||
passed = verify(problem_shape);
|
||||
if (passed) example_result.verified = true;
|
||||
}
|
||||
|
||||
if (!passed) {
|
||||
std::cerr << "Reference check failed" << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
example_result.passed = true;
|
||||
|
||||
return example_result;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to print a description of the example run and its result
|
||||
void print_result(const std::string& description, ExampleResult result, bool verbose) {
|
||||
std::ios fmt(nullptr);
|
||||
fmt.copyfmt(std::cout);
|
||||
std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ");
|
||||
std::cout << std::setw(32) << std::left << description;
|
||||
std::cout.copyfmt(fmt);
|
||||
std::cout << " : " << result.tflops_tc_s << " TFLOPS/s " << result.tbytes_s << " TB/s" << std::endl;
|
||||
if (verbose) {
|
||||
std::cout << " t=" << result.runtime_ms * 1e3 << " us, "
|
||||
"smem=" << result.smem_size << "b" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_mla(Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||
auto run = [&](auto shape, const char* name, auto... kernel_options) {
|
||||
if ((! options.kernel_filter.empty()) && (! std::regex_search(name, std::basic_regex(options.kernel_filter)))) {
|
||||
return;
|
||||
}
|
||||
Runner<decltype(shape), decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
};
|
||||
|
||||
using NumHeads = _128;
|
||||
using HeadDimLatent = _512;
|
||||
using HeadDim = Shape<HeadDimLatent, _64>;
|
||||
|
||||
std::cout << "###### B " << options.b << " MLA H " << 0 + NumHeads{} << " ";
|
||||
std::cout << "D_rope " << 0 + get<1>(HeadDim{}) << " D_latent " << 0 + get<0>(HeadDim{}) << " ";
|
||||
std::cout << "Q 1 K " << options.k << " Gen None ";
|
||||
std::cout << "Split " << options.split_kv << " Gen None ";
|
||||
std::cout << "#SM " << hw_info.sm_count << std::endl;
|
||||
|
||||
using Blocking = _128;
|
||||
std::string name = std::to_string((int) NumHeads{}) + "x" + std::to_string((int) Blocking{});
|
||||
std::string individual = " individual";
|
||||
std::string persistent = " persistent";
|
||||
#if FP8
|
||||
name += " fp8";
|
||||
// Persistent Tile Scheduler
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + persistent).c_str(), IsPersistent<true>{});
|
||||
// Individual Tile Scheduler
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
|
||||
#elif FP16
|
||||
name += " fp16";
|
||||
// Persistent Tile Scheduler
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + persistent).c_str(), IsPersistent<true>{});
|
||||
// Individual Tile Scheduler
|
||||
run(Shape<NumHeads, Blocking, HeadDim>{}, (name + individual).c_str(), IsPersistent<false>{});
|
||||
#endif
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
int main_single(int argc, char const **args) {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || props.major != 10) {
|
||||
std::cout
|
||||
<< "This example requires a GPU of NVIDIA's Blackwell Architecture "
|
||||
<< "(compute capability major 10) and CUDA 12.8 or greater.\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.error) {
|
||||
std::cerr << "Aborting execution." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
//
|
||||
// Run examples
|
||||
//
|
||||
|
||||
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
|
||||
// information is used by the underlying kernel.
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
if (options.sm_count == 0) {
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
else {
|
||||
hw_info.sm_count = options.sm_count;
|
||||
}
|
||||
|
||||
run_mla(options, hw_info);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
std::vector<std::string> full_arguments(args, args + argc);
|
||||
|
||||
int result = 0;
|
||||
|
||||
bool recursed = false;
|
||||
for (size_t i = 1; i < full_arguments.size(); i++) {
|
||||
if (full_arguments[i].find(',') != std::string::npos) {
|
||||
auto arg = full_arguments[i];
|
||||
size_t eq_pos = arg.find('=');
|
||||
std::string prefix = eq_pos == std::string::npos ? "" : arg.substr(0, eq_pos+1);
|
||||
std::string rest = eq_pos == std::string::npos ? arg : arg.substr(eq_pos+1);
|
||||
for (;;) {
|
||||
size_t comma_pos = rest.find(',');
|
||||
std::string current = rest.substr(0, comma_pos);
|
||||
full_arguments[i] = prefix + current;
|
||||
std::vector<const char*> next_args;
|
||||
for (auto& elem : full_arguments) { next_args.push_back(elem.data()); }
|
||||
main(argc, next_args.data());
|
||||
if (comma_pos == std::string::npos) break;
|
||||
rest = rest.substr(comma_pos+1);
|
||||
}
|
||||
recursed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (! recursed) {
|
||||
main_single(argc, args);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
147
examples/77_blackwell_fmha/CMakeLists.txt
Normal file
147
examples/77_blackwell_fmha/CMakeLists.txt
Normal file
@ -0,0 +1,147 @@
|
||||
# Copyright (c) 2014 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
set_property(
|
||||
SOURCE
|
||||
77_blackwell_fmha.cu
|
||||
77_blackwell_fmha_gen.cu
|
||||
77_blackwell_mla.cu
|
||||
77_blackwell_fmha_bwd.cu
|
||||
PROPERTY
|
||||
COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0"
|
||||
)
|
||||
|
||||
set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no)
|
||||
set(TEST_CAUSAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
|
||||
set(TEST_VARLEN --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=residual --varlen)
|
||||
set(TEST_HDIM64 --b=2 --h=4 --q=512 --k=512 --d=64 --verify)
|
||||
set(TEST_GQA --b=2 --h=4 --h_k=2 --q=512 --k=512 --d=64 --verify)
|
||||
|
||||
set(TEST_GEN_BASIC --b=1 --h=4 --k=512 --d=128 --verify)
|
||||
set(TEST_GEN_VARLEN --b=1 --h=4 --k=512 --d=128 --verify --varlen)
|
||||
set(TEST_GEN_HDIM64 --b=2 --h=4 --k=512 --d=64 --verify)
|
||||
set(TEST_GEN_GQA --b=2 --h=4 --h_k=2 --k=512 --d=64 --verify)
|
||||
set(TEST_GEN_REMAP --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --remap)
|
||||
set(TEST_GEN_CACHEONLY --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --cache-only)
|
||||
|
||||
set(TEST_MLA_BASIC --b=1 --k=512 --verify)
|
||||
|
||||
if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC_ARCHS MATCHES 100a))
|
||||
|
||||
foreach(PREC fp8 fp16)
|
||||
string(TOUPPER "${PREC}" PREC_MACRO)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_${PREC}
|
||||
77_blackwell_fmha.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_BASIC
|
||||
# TEST_CAUSAL
|
||||
# TEST_VARLEN
|
||||
# TEST_HDIM64
|
||||
# TEST_GQA)
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_${PREC} PRIVATE ${PREC_MACRO})
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_gen_${PREC}
|
||||
77_blackwell_fmha_gen.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_GEN_BASIC
|
||||
# TEST_GEN_VARLEN
|
||||
# TEST_GEN_HDIM64
|
||||
# TEST_GEN_GQA
|
||||
# TEST_GEN_REMAP
|
||||
# TEST_GEN_CACHEONLY)
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_gen_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_gen_${PREC} PRIVATE ${PREC_MACRO})
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_mla_2sm_${PREC}
|
||||
77_blackwell_mla.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_MLA_BASIC
|
||||
)
|
||||
target_include_directories(77_blackwell_mla_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_mla_2sm_${PREC} PRIVATE ${PREC_MACRO})
|
||||
target_compile_options(77_blackwell_mla_2sm_${PREC} PRIVATE -Xptxas -v)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_mla_2sm_cpasync_${PREC}
|
||||
77_blackwell_mla.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_MLA_BASIC
|
||||
)
|
||||
target_include_directories(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${PREC_MACRO} CPASYNC)
|
||||
target_compile_options(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE -Xptxas -v)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_mla_b2b_2sm_${PREC}
|
||||
77_blackwell_mla.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_MLA_BASIC
|
||||
)
|
||||
target_include_directories(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE ${PREC_MACRO} B2B)
|
||||
target_compile_options(77_blackwell_mla_b2b_2sm_${PREC} PRIVATE -Xptxas -v)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_bwd_${PREC}
|
||||
77_blackwell_fmha_bwd.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_BASIC
|
||||
# TEST_GEN_VARLEN
|
||||
# TEST_GEN_HDIM64
|
||||
# TEST_GEN_GQA
|
||||
# TEST_GEN_REMAP
|
||||
# TEST_GEN_CACHEONLY)
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO})
|
||||
target_compile_options(77_blackwell_fmha_bwd_${PREC} PRIVATE -Xptxas -v)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_bwd_sat_${PREC}
|
||||
77_blackwell_fmha_bwd.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_BASIC
|
||||
# TEST_GEN_VARLEN
|
||||
TEST_GEN_HDIM64
|
||||
# TEST_GEN_GQA
|
||||
# TEST_GEN_REMAP
|
||||
# TEST_GEN_CACHEONLY)
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${PREC_MACRO} SKIP_ATOMIC)
|
||||
target_compile_options(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE -Xptxas -v)
|
||||
endforeach()
|
||||
endif()
|
||||
88
examples/77_blackwell_fmha/README.md
Normal file
88
examples/77_blackwell_fmha/README.md
Normal file
@ -0,0 +1,88 @@
|
||||
# FMHA for Blackwell: Forward
|
||||
|
||||
This sample provides code for fused multi-head attention forward, context, or generation phase.
|
||||
It supports HeadDims of 32, 64, and 128, and fp8, fp16, and bf16 input data types.
|
||||
|
||||
For forward or context usage, use an M-blocking (Seqlen-Q) of 256 and an N-blocking (Seqlen-K) of 128.
|
||||
For generation usage, use an M-blocking (Num-Groups) of 128 (although the limit is currently 32 for actual Num-Groups), and a N-blocking (Seqlen-K) of 64, 128 or 256.
|
||||
|
||||
Context loads are done via TMA, whereas generation usage utilized `cp.async` and is thus more amenable to complex load patterns.
|
||||
|
||||
For variable sequence lenght, the code requires a batch of valid (but never used) padding memory ahead of the first input batch. This is achieved with least overhead by leaving one batch free and then arranging QKV consecutively.
|
||||
|
||||
The approach of this implementation is to reuse the selection logic of the collective gemm builder and recombine the result into an FMHA kernel.
|
||||
The kernel and collective layer are then formulated to be fmha-specific.
|
||||
The design assigns two tiles to each threadblock, and pingpongs between them in terms of matrix-matrix multiplication and softmax.
|
||||
|
||||
The example builds four binaries, showcasing the context and generation usage for fp8 and fp16.
|
||||
For detailed information on how to invoke them, check out either the tests in `CMakeLists.txt` or the `--help` for them.
|
||||
|
||||
To modify the code for fusions, `collective/fmha_fusion.hpp` provides the easiest customization point.
|
||||
The `apply_mask` function is called with the accumulator of the first GEMM and the logical positions of those elements.
|
||||
It is well-suited for applying masks or activations.
|
||||
More complex fusions that require memory loads would require modifying the mainloop collective to orchestrate the load via TMA.
|
||||
|
||||
# FMHA for Blackwell: Backward
|
||||
|
||||
This sample provides code for fused multi-head attention backward pass.
|
||||
It supports HeadDims of 64 and 128, and fp8, fp16, and bf16 input data types.
|
||||
The blocking in sequence length Q and K is 128, loads are done via TMA.
|
||||
We support causal masking.
|
||||
The structure of this code is very similar to the forward pass, and the techniques are analogous.
|
||||
|
||||
There are three kernels to compute backwards:
|
||||
1. `FmhaKernelBwdSumOdO` to compute the sum of the outer product of O and dO.
|
||||
3. `Sm100FmhaBwdKernelTmaWarpSpecialized` to compute the backward pass.
|
||||
2. `FmhaKernelBwdConvert` to convert the dQ from fp32 to the final output precision.
|
||||
|
||||
`Sm100FmhaBwdKernelTmaWarpSpecialized` is the main point of this sample, as it demonstrates how to use tensor cores to achieve a high performance fused kernel.
|
||||
|
||||
# MLA Inference for Blackwell
|
||||
|
||||
This sample provides code for fused multi-head latent attention inference in
|
||||
the weight-absorbed regime, i.e. for latent head dim 512, and rope head dim 64.
|
||||
It supports fp16, bf16, and fp8 input and output types.
|
||||
|
||||
To accomodate the large output accumulator due to the large latent head dimension,
|
||||
the sample demonstrates how to leverage 2Sm Blackwell tensor cores.
|
||||
|
||||
Loading can be done via TMA (either without paging or with page size 128), or using `cp.async`
|
||||
for support of any power-of-two page size less than or equal to 128.
|
||||
With paging, the code also supports variable sequence length.
|
||||
|
||||
The approach of this implementation is to reuse the selection logic of the collective gemm builder and recombine the result into an MLA kernel.
|
||||
|
||||
The example builds six binaries, showcasing TMA and `cp.async` usage, as well as a back-to-back gemm (essentially turning the softmax into a no-op) for fp8 and fp16.
|
||||
For detailed information on how to invoke them, check out either the tests in `CMakeLists.txt` or the `--help` for them.
|
||||
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
127
examples/77_blackwell_fmha/collective/fmha_common.hpp
Normal file
127
examples/77_blackwell_fmha/collective/fmha_common.hpp
Normal file
@ -0,0 +1,127 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
#include "cutlass/arch/reg_reconfig.h"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
namespace cutlass::fmha::collective {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<typename Atom, typename TA, typename TB, typename TC>
|
||||
CUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
|
||||
constexpr int rA = decltype(rank(tA))::value;
|
||||
constexpr int rB = decltype(rank(tB))::value;
|
||||
constexpr int rC = decltype(rank(tC))::value;
|
||||
static_assert(rA == 3 && rB == 3 && rC == 3);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tA); k_block++) {
|
||||
cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC);
|
||||
atom.accumulate_ = decltype(atom.accumulate_)::One;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Atom, typename TA, typename TB, typename TC>
|
||||
CUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
|
||||
atom.accumulate_ = decltype(atom.accumulate_)::Zero;
|
||||
gemm_reset_zero_acc(atom, tA, tB, tC);
|
||||
}
|
||||
|
||||
template<class Layout, class Stages = _1>
|
||||
CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) {
|
||||
return composition(layout, prepend<decltype(rank(layout))::value>(make_layout(stages), _));
|
||||
}
|
||||
|
||||
template<class T>
|
||||
CUTE_DEVICE T warp_uniform(T a) {
|
||||
return __shfl_sync(0xffffffff, a, 0);
|
||||
}
|
||||
|
||||
template <class a_type, class b_type, class c_type,
|
||||
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
|
||||
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg, class... TAs, class... TMs>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
to_tiled_mma_sm100_ts(
|
||||
TiledMMA<MMA_Atom<
|
||||
MMA_Traits<SM100_MMA_F8F6F4_SS, a_type, b_type, c_type,
|
||||
cute::C<M>, cute::C<N>,
|
||||
cute::integral_constant<UMMA::Major, a_major>,
|
||||
cute::integral_constant<UMMA::Major, b_major>,
|
||||
cute::integral_constant<UMMA::ScaleIn, a_neg>,
|
||||
cute::integral_constant<UMMA::ScaleIn, b_neg>>,
|
||||
TAs...>, TMs...>) {
|
||||
|
||||
return TiledMMA<MMA_Atom<
|
||||
MMA_Traits<SM100_MMA_F8F6F4_TS<a_type, b_type, c_type,
|
||||
M, N,
|
||||
a_major, b_major,
|
||||
a_neg, b_neg, UMMA::Saturate::False>>,
|
||||
TAs...>, TMs...>{};
|
||||
}
|
||||
|
||||
template <class a_type, class b_type, class c_type,
|
||||
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
|
||||
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg, class... TAs, class... TMs>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
to_tiled_mma_sm100_ts(
|
||||
TiledMMA<MMA_Atom<
|
||||
SM100_MMA_F16BF16_SS<a_type, b_type, c_type,
|
||||
M, N,
|
||||
a_major,
|
||||
b_major,
|
||||
a_neg,
|
||||
b_neg>,
|
||||
TAs...>, TMs...>) {
|
||||
return TiledMMA<MMA_Atom<
|
||||
SM100_MMA_F16BF16_TS<a_type, b_type, c_type,
|
||||
M, N,
|
||||
a_major, b_major,
|
||||
a_neg, b_neg, UMMA::Saturate::False>,
|
||||
TAs...>, TMs...>{};
|
||||
}
|
||||
|
||||
template<uint32_t RegCount>
|
||||
CUTLASS_DEVICE
|
||||
void warpgroup_reg_set() {
|
||||
if constexpr (RegCount < 128) {
|
||||
cutlass::arch::warpgroup_reg_dealloc<RegCount>();
|
||||
}
|
||||
else {
|
||||
cutlass::arch::warpgroup_reg_alloc<RegCount>();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
254
examples/77_blackwell_fmha/collective/fmha_fusion.hpp
Normal file
254
examples/77_blackwell_fmha/collective/fmha_fusion.hpp
Normal file
@ -0,0 +1,254 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
namespace cutlass::fmha::collective {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
struct NoMask {
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
return ceil_div(get<1>(problem_size), get<1>(tile_shape));
|
||||
}
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_masked_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_unmasked_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
return get_trip_count(blk_coord, tile_shape, problem_size);
|
||||
}
|
||||
|
||||
template<class AccQK, class IndexQK, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
void apply_mask(
|
||||
AccQK& acc_qk,
|
||||
IndexQK const& index_qk,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
struct ResidualMask : NoMask {
|
||||
|
||||
using Base = NoMask;
|
||||
|
||||
template <class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE int get_masked_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_unmasked_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
// if the sequence length does not divide the tile size evenly
|
||||
if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
|
||||
return get_trip_count(blk_coord, tile_shape, problem_size) - 1;
|
||||
}
|
||||
return get_trip_count(blk_coord, tile_shape, problem_size);
|
||||
}
|
||||
|
||||
template<class AccQK, class IndexQK, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
void apply_mask(
|
||||
AccQK& acc_qk,
|
||||
IndexQK const& index_qk,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
// This is useful is seqlen_k % kBlockN != 0 since it masks
|
||||
// the remaining elements out from softmax.
|
||||
// d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar
|
||||
// issues as they are transparently taken care of by TMA and the
|
||||
// epilogue, if it is instantiated with predication support.
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(acc_qk); i++) {
|
||||
auto pos = index_qk(i);
|
||||
if (get<1>(pos) >= get<1>(problem_size)) {
|
||||
acc_qk(i) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct CausalMask : NoMask {
|
||||
|
||||
using Base = NoMask;
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
// See note below on different ways to think about causal attention
|
||||
// Again, we'd add the offset_q into the max_blocks_q calculation
|
||||
int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
|
||||
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));
|
||||
return std::min(max_blocks_k, max_blocks_q);
|
||||
}
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_masked_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
return ceil_div(get<0>(tile_shape), get<1>(tile_shape));
|
||||
}
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_unmasked_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size);
|
||||
}
|
||||
|
||||
template<class AccQK, class IndexQK, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
void apply_mask(
|
||||
AccQK& acc_qk,
|
||||
IndexQK const& index_qk,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
// There are two ways to do causal if N_Q != N_K
|
||||
// (1) is to assume that the Q is at the beginning of the matrix
|
||||
// - this is what we demonstrate here
|
||||
// (2) is that it is at the end of the matrix
|
||||
// - this is usually what we want for inference settings
|
||||
// where we only compute the next row and use cache for the rest
|
||||
// - if you'd like this, you only need to add an offset like so:
|
||||
// get<0>(pos) + offset_q < get<1>(pos)
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(acc_qk); i++) {
|
||||
auto pos = index_qk(i);
|
||||
if ((get<0>(pos) < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) {
|
||||
acc_qk(i) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
struct VariableLength {
|
||||
int max_length;
|
||||
int* cumulative_length = nullptr;
|
||||
|
||||
CUTE_HOST_DEVICE operator int() const {
|
||||
return max_length;
|
||||
}
|
||||
};
|
||||
|
||||
template<class T> struct is_variable_length : std::false_type {};
|
||||
template<> struct is_variable_length<VariableLength> : std::true_type {};
|
||||
template<class T> constexpr bool is_variable_length_v = is_variable_length<T>::value;
|
||||
|
||||
template<class Shape, class Idx>
|
||||
CUTE_HOST_DEVICE
|
||||
constexpr auto
|
||||
apply_variable_length(Shape const& shape, Idx const& idx) {
|
||||
return transform_leaf(shape, [&](auto const& s) {
|
||||
if constexpr (is_variable_length_v<remove_cvref_t<decltype(s)>>) {
|
||||
return s.cumulative_length[idx+1] - s.cumulative_length[idx];
|
||||
}
|
||||
else {
|
||||
return s;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template<class Shape, class Coord, class Idx>
|
||||
CUTE_HOST_DEVICE
|
||||
constexpr auto
|
||||
apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) {
|
||||
auto new_shape = apply_variable_length(shape, idx);
|
||||
auto new_coord = transform_leaf(shape, coord, [&](auto const& s, auto const& c) {
|
||||
if constexpr (is_variable_length_v<remove_cvref_t<decltype(s)>>) {
|
||||
return cute::make_tuple(c, s.cumulative_length[idx]);
|
||||
}
|
||||
else {
|
||||
return c;
|
||||
}
|
||||
});
|
||||
return cute::make_tuple(new_shape, new_coord);
|
||||
}
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
|
||||
namespace cute {
|
||||
|
||||
template<>
|
||||
struct is_integral<cutlass::fmha::collective::VariableLength> : true_type {};
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
void print(cutlass::fmha::collective::VariableLength a) {
|
||||
printf("Varlen<%d, %p>", a.max_length, a.cumulative_length);
|
||||
}
|
||||
|
||||
}
|
||||
@ -0,0 +1,200 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/layout.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
namespace cutlass::fmha::collective {
|
||||
|
||||
template<
|
||||
class Element,
|
||||
class ElementAcc,
|
||||
class TileShape, // Q, D, _
|
||||
class StrideO, // Q, D, B
|
||||
class StrideLSE // Q, B
|
||||
>
|
||||
struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
|
||||
|
||||
using Pipeline = cutlass::PipelineAsync<2>;
|
||||
|
||||
// using SmemLayoutO = decltypa(make_layout(append<3>(select<0,1>(TileShape_WG{}), _2{})));
|
||||
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
|
||||
cute::UMMA::Major::K, Element, tuple_element_t<0, TileShape>, tuple_element_t<1, TileShape>>());
|
||||
// using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{}));
|
||||
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{}));
|
||||
using SmemLayoutO_ = SmemLayoutO;
|
||||
|
||||
struct TensorStorage {
|
||||
|
||||
using SmemLayoutO = SmemLayoutO_;
|
||||
cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>> smem_o;
|
||||
|
||||
};
|
||||
|
||||
struct Arguments {
|
||||
Element* ptr_O;
|
||||
StrideO dO;
|
||||
|
||||
ElementAcc* ptr_LSE;
|
||||
StrideLSE dLSE;
|
||||
};
|
||||
|
||||
using TMA_O = decltype(make_tma_copy(
|
||||
SM90_TMA_STORE{},
|
||||
make_tensor((Element*) nullptr, repeat_like(StrideO{}, 0), StrideO{}),
|
||||
SmemLayoutO{}(_,_,_0{})
|
||||
));
|
||||
|
||||
|
||||
struct Params {
|
||||
TMA_O tma_store_o;
|
||||
};
|
||||
|
||||
template<class ProblemShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape,
|
||||
Arguments const& args,
|
||||
void* workspace = nullptr) {
|
||||
|
||||
auto ptr_O = args.ptr_O;
|
||||
StrideO dO = args.dO;
|
||||
auto problem_shape_O = select<0,2,3>(problem_shape);
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
|
||||
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
|
||||
if (cumulative_length_q != nullptr) {
|
||||
int max_length_q = get<0>(problem_shape).max_length;
|
||||
// for variable sequence lenght, the batch is in units of row_stride
|
||||
get<2,1>(dO) = get<0>(dO);
|
||||
get<2,1>(problem_shape_O) = max_length_q * (1 + get<2,1>(problem_shape_O));
|
||||
// offset ptr by the amount we add back in later
|
||||
ptr_O -= max_length_q * get<0>(dO);
|
||||
}
|
||||
}
|
||||
|
||||
auto tma_store_o = make_tma_copy(
|
||||
SM90_TMA_STORE{},
|
||||
make_tensor(ptr_O, problem_shape_O, dO),
|
||||
SmemLayoutO{}(_,_,_0{})
|
||||
);
|
||||
|
||||
return {
|
||||
tma_store_o
|
||||
};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& params) {
|
||||
cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor());
|
||||
}
|
||||
|
||||
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
|
||||
CUTLASS_DEVICE auto
|
||||
store(
|
||||
BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,
|
||||
Params const& params, ParamsProblemShape const& params_problem_shape,
|
||||
TensorStorage& shared_storage,
|
||||
Pipeline& pipeline, typename Pipeline::PipelineState& pipeline_consumer_state) {
|
||||
|
||||
BlkCoord blk_coord = blk_coord_in;
|
||||
uint32_t lane_predicate = cute::elect_one_sync();
|
||||
|
||||
using X = Underscore;
|
||||
|
||||
int o0_index = 2 * get<0>(blk_coord);
|
||||
int o1_index = 2 * get<0>(blk_coord) + 1;
|
||||
|
||||
Tensor mO_qdl_p = params.tma_store_o.get_tma_tensor(select<0,2,3>(problem_shape));
|
||||
// offset mode 0 by (max_length - real_length)
|
||||
// offset mode 3,1 by cumulative_length + real_length
|
||||
// the ptr is already offset by - max_length
|
||||
// so in total this achieves
|
||||
int offs_0 = 0;
|
||||
int offs_2_1 = 0;
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
|
||||
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
|
||||
if (cumulative_length_q != nullptr) {
|
||||
int max_length_q = get<0>(params_problem_shape).max_length;
|
||||
offs_0 = max_length_q - get<0>(problem_shape);
|
||||
offs_2_1 = cumulative_length_q[get<2,1>(blk_coord)] + get<0>(problem_shape);
|
||||
get<2,1>(blk_coord) = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor mO_qdl = domain_offset(make_coord(offs_0, _0{}, make_coord(_0{}, offs_2_1)), mO_qdl_p);
|
||||
|
||||
Tensor gO_qdl = local_tile(mO_qdl, TileShape{}, make_coord(_, _, _), Step<_1, _1, X>{});
|
||||
Tensor gO = gO_qdl(_, _, _, _0{}, get<2>(blk_coord));
|
||||
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
|
||||
auto block_tma = params.tma_store_o.get_slice(0);
|
||||
Tensor tOsO = block_tma.partition_S(sO);
|
||||
Tensor tOgO = block_tma.partition_D(gO);
|
||||
|
||||
auto pipeline_release_state = pipeline_consumer_state;
|
||||
|
||||
// O1 O2
|
||||
// one pipeline: O
|
||||
// wait from corr, issue tma store on smem
|
||||
pipeline.consumer_wait(pipeline_consumer_state);
|
||||
++pipeline_consumer_state;
|
||||
|
||||
if (lane_predicate) {
|
||||
copy(params.tma_store_o, tOsO(_,_,_,_0{}), tOgO(_,_,_,o0_index));
|
||||
}
|
||||
tma_store_arrive();
|
||||
|
||||
pipeline.consumer_wait(pipeline_consumer_state);
|
||||
++pipeline_consumer_state;
|
||||
|
||||
if (lane_predicate) {
|
||||
copy(params.tma_store_o, tOsO(_,_,_,_1{}), tOgO(_,_,_,o1_index));
|
||||
}
|
||||
tma_store_arrive();
|
||||
|
||||
tma_store_wait<1>();
|
||||
|
||||
pipeline.consumer_release(pipeline_release_state);
|
||||
++pipeline_release_state;
|
||||
|
||||
tma_store_wait<0>();
|
||||
|
||||
pipeline.consumer_release(pipeline_release_state);
|
||||
++pipeline_release_state;
|
||||
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,4 +1,4 @@
|
||||
/******************************************************************************
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
@ -27,58 +27,68 @@
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief generic device-to-device data movement kernel based for CuTe tensors.
|
||||
|
||||
NOTE: this kernel assigns one element copy to every thread, and is by no means
|
||||
an efficient way of copying tensors. It should only be used for convenience in
|
||||
reference checks.
|
||||
|
||||
*/
|
||||
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cute/layout.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/cuda_host_adapter.hpp"
|
||||
#include "cute/layout.hpp"
|
||||
|
||||
namespace cutlass {
|
||||
namespace cutlass::fmha::collective {
|
||||
|
||||
template <typename TensorSource, typename TensorDestination>
|
||||
void device_copy(TensorSource tensor_source,
|
||||
TensorDestination tensor_destination,
|
||||
cudaStream_t stream);
|
||||
template<
|
||||
class Element_,
|
||||
class StrideO_
|
||||
>
|
||||
struct Sm100FmhaGenEpilogueWarpspecialized {
|
||||
|
||||
using Pipeline = cutlass::PipelineAsync<2>;
|
||||
|
||||
|
||||
template <typename TensorSource, typename TensorDestination>
|
||||
__global__ void device_copy_kernel(TensorSource const tensor_source,
|
||||
TensorDestination tensor_destination) {
|
||||
auto linear_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
using ElementSrc = typename TensorSource::value_type;
|
||||
using ElementDst = typename TensorDestination::value_type;
|
||||
NumericConverter<ElementDst, ElementSrc> converter;
|
||||
if (linear_idx < size(tensor_source)) {
|
||||
tensor_destination(linear_idx) = converter(tensor_source(linear_idx));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TensorSource, typename TensorDestination>
|
||||
void device_copy(TensorSource tensor_source,
|
||||
TensorDestination tensor_destination,
|
||||
cudaStream_t stream) {
|
||||
using SmemLayoutO = Layout<Shape<_1, _1, _1>>;
|
||||
using SmemLayoutO_ = SmemLayoutO;
|
||||
using Element = Element_;
|
||||
using StrideOOrig = StrideO_;
|
||||
using StrideO = decltype(replace<0>(StrideOOrig{}, 0));
|
||||
|
||||
assert(tensor_source.size() == tensor_destination.size());
|
||||
struct TensorStorage {
|
||||
|
||||
auto numel = tensor_source.size();
|
||||
static constexpr int NumThreads = 128;
|
||||
auto grid_size = cute::ceil_div(numel, NumThreads);
|
||||
using SmemLayoutO = SmemLayoutO_;
|
||||
cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>> smem_o;
|
||||
|
||||
dim3 grid(grid_size);
|
||||
dim3 block(NumThreads);
|
||||
device_copy_kernel<<<grid, block, 0, stream>>>(tensor_source, tensor_destination);
|
||||
}
|
||||
};
|
||||
|
||||
} //namespace cutlass
|
||||
struct Arguments {
|
||||
Element* ptr_o;
|
||||
StrideO dO;
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
const Params& params;
|
||||
|
||||
CUTLASS_DEVICE Sm100FmhaGenEpilogueWarpspecialized(const Params& params) : params(params) {}
|
||||
|
||||
template<class ProblemShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape,
|
||||
Arguments const& args,
|
||||
void* workspace = nullptr) {
|
||||
return args;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& params) {
|
||||
/* no-op */
|
||||
}
|
||||
|
||||
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
|
||||
CUTLASS_DEVICE auto
|
||||
store(
|
||||
BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,
|
||||
Params const& params, ParamsProblemShape const& params_problem_shape,
|
||||
TensorStorage& shared_storage,
|
||||
Pipeline& pipeline, typename Pipeline::PipelineState& pipeline_consumer_state) {
|
||||
/* no-op */
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,384 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/memory_sm80.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/layout.hpp"
|
||||
|
||||
#include "collective/fmha_common.hpp"
|
||||
#include "collective/fmha_fusion.hpp"
|
||||
|
||||
namespace cutlass::fmha::collective {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<
|
||||
class Element,
|
||||
class StrideQ,
|
||||
class StrideNewK,
|
||||
class StrideNewV,
|
||||
class StrideCacheK,
|
||||
class StrideCacheV,
|
||||
class TensorStorage,
|
||||
class CollectiveMmaQK,
|
||||
class CollectiveMmaPV,
|
||||
class SmemLayoutQ,
|
||||
class SmemLayoutK,
|
||||
class SmemLayoutV,
|
||||
class PipelineQ,
|
||||
class PipelineKV,
|
||||
class TileShape,
|
||||
class Mask
|
||||
>
|
||||
struct Sm100FmhaLoadCpAsyncWarpspecialized {
|
||||
|
||||
using TileShapeQK = typename CollectiveMmaQK::TileShape;
|
||||
using TileShapePV = typename CollectiveMmaPV::TileShape;
|
||||
|
||||
struct Arguments {
|
||||
|
||||
const int* cache_batch_idx;
|
||||
|
||||
const Element* ptr_q;
|
||||
StrideQ dQ;
|
||||
|
||||
const Element* ptr_new_k;
|
||||
StrideNewK dNewK;
|
||||
const Element* ptr_new_v;
|
||||
StrideNewV dNewV;
|
||||
|
||||
Element* ptr_cache_k;
|
||||
StrideCacheK dCacheK;
|
||||
Element* ptr_cache_v;
|
||||
StrideCacheV dCacheV;
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template<class ProblemShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape,
|
||||
Arguments const& args,
|
||||
void* workspace) {
|
||||
|
||||
return args;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& params) {
|
||||
}
|
||||
|
||||
template<class TEngine, class TLayout>
|
||||
CUTLASS_DEVICE auto constexpr transpose(Tensor<TEngine, TLayout> const& t) {
|
||||
CUTE_STATIC_ASSERT_V(rank(t) == _2{});
|
||||
return t.compose(make_layout(make_shape(size<1>(t), size<0>(t)), make_stride(size<0>(t), _1{})));
|
||||
}
|
||||
|
||||
template<
|
||||
class CAtom, class TA, class TB,
|
||||
class CountTensor, class CountLimit,
|
||||
class SrcTensor, class DstTensor
|
||||
>
|
||||
CUTLASS_DEVICE void copy_with_limit(
|
||||
TiledCopy<CAtom, TA, TB> const& tiled_copy,
|
||||
CountTensor const& c, CountLimit const& l,
|
||||
SrcTensor const& src, DstTensor&& dst) {
|
||||
|
||||
//copy(tiled_copy, src, dst);
|
||||
#if 1
|
||||
auto c_f = make_tensor(c.data(), flatten(c.layout()));
|
||||
auto src_f = make_tensor(src.data(), flatten(src.layout()));
|
||||
auto dst_f = make_tensor(dst.data(), flatten(dst.layout()));
|
||||
auto c_v = group_modes<1,rank_v<decltype(c_f)>>(c_f);
|
||||
auto src_v = group_modes<1,rank_v<decltype(src_f)>>(src_f);
|
||||
auto dst_v = group_modes<1,rank_v<decltype(dst_f)>>(dst_f);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<1>(src_v); i++) {
|
||||
if (elem_less(c_v(_0{}, i), l)) {
|
||||
copy(CAtom{}, src_v(_, i), dst_v(_, i));
|
||||
}
|
||||
else {
|
||||
clear(dst_v(_, i));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
|
||||
CUTLASS_DEVICE void
|
||||
load(
|
||||
BlkCoord const& blk_coord, ProblemShape const& problem_shape,
|
||||
Params const& params, ParamsProblemShape const& params_problem_shape,
|
||||
TensorStorage& storage,
|
||||
PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state,
|
||||
PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) {
|
||||
|
||||
int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);
|
||||
mask_tile_count *= 2;
|
||||
|
||||
int warp_idx = (threadIdx.x / 32) % 2;
|
||||
int thread_idx = warp_idx * 32 + (threadIdx.x % 32);
|
||||
|
||||
using X = Underscore;
|
||||
|
||||
// this one is only executed by one thread, no need to elect_one
|
||||
auto blk_coord_cache = blk_coord;
|
||||
if (params.cache_batch_idx != nullptr) {
|
||||
get<2,1>(blk_coord_cache) = params.cache_batch_idx[get<2,1>(blk_coord_cache)];
|
||||
}
|
||||
|
||||
// Q1, K1, K2, V1, K3, V2, ... Kn, Vn-1, Vn
|
||||
// two pipes: Q and KV
|
||||
auto cQ = make_identity_tensor(select<0,2>(TileShape{}));
|
||||
auto mQ = make_tensor(make_gmem_ptr(params.ptr_q), append<3>(select<0,2>(TileShapeQK{}), get<3>(problem_shape)), params.dQ);
|
||||
auto gQ = mQ(_, _, get<2>(blk_coord));
|
||||
auto sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
|
||||
|
||||
typename CollectiveMmaQK::TiledMma mma_qk;
|
||||
ThrMMA thr_mma_qk = mma_qk.get_slice(0);
|
||||
auto tSgQ = thr_mma_qk.partition_A(gQ);
|
||||
auto tScQ = thr_mma_qk.partition_A(cQ);
|
||||
|
||||
auto atom_q_tv = Layout<Shape<Shape<_2, _32>, Shape<_16, _16>>, Stride<Stride<_16, _32>, Stride<_1, _1024>>>{};
|
||||
auto atom_kv_tv = Layout<Shape<Shape<_2, _32>, Shape<_16, _4>>, Stride<Stride<_16, _32>, Stride<_1, _1024>>>{};
|
||||
|
||||
auto tiled_copy_q = make_cotiled_copy(
|
||||
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, Element>{},
|
||||
atom_q_tv,
|
||||
make_layout(shape(tSgQ), replace<0>(stride(tSgQ), replace<0>(stride<0>(tSgQ), get<2>(TileShape{})))));
|
||||
|
||||
auto thr_copy_q = tiled_copy_q.get_slice(thread_idx);
|
||||
|
||||
auto tQsQ = thr_copy_q.partition_D(sQ);
|
||||
auto tQgQ = thr_copy_q.partition_S(tSgQ);
|
||||
auto tQcQ = thr_copy_q.partition_S(tScQ);
|
||||
|
||||
auto limitQ = append<2>(get<0>(problem_shape), _128{});
|
||||
|
||||
// Q1
|
||||
int q0_index = get<0>(blk_coord);
|
||||
|
||||
auto load_q = [&](int q_index, auto& state) {
|
||||
pipeline_q.producer_acquire(state);
|
||||
|
||||
// q is always loaded masked
|
||||
using Vec = uint128_t;
|
||||
Vec vzero = uint128_t(0, 0);
|
||||
auto src = recast<Vec>(tQgQ(_, _, _, _));
|
||||
auto dst = recast<Vec>(tQsQ(_, _, _, _, state.index()));
|
||||
auto c = tQcQ(_, _, _, _);
|
||||
int vlen = sizeof(Vec) / sizeof(Element);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(src); i++) {
|
||||
auto cc = c(vlen*i);
|
||||
Vec* dst_ptr = &dst(i);
|
||||
const Vec* src_ptr = &src(i);
|
||||
bool guard = elem_less(cc, limitQ);
|
||||
cutlass::arch::cp_async_zfill<16, cutlass::arch::CacheOperation::Always>(
|
||||
dst_ptr, src_ptr, guard
|
||||
);
|
||||
}
|
||||
|
||||
pipeline_q.producer_commit(state, cutlass::arch::cpasync_barrier_arrive);
|
||||
};
|
||||
|
||||
load_q(q0_index, pipeline_q_producer_state);
|
||||
++pipeline_q_producer_state;
|
||||
|
||||
auto cK_t = make_identity_tensor(select<1,2>(TileShapeQK{}));
|
||||
auto cK = make_tensor(cK_t.data(), make_layout(get<0>(cK_t.layout()), get<1>(cK_t.layout()), make_layout(_2{}, get<1>(TileShapeQK{}) * stride<0>(cK_t))));
|
||||
auto mK = make_tensor(make_gmem_ptr(params.ptr_cache_k), select<1,2,3>(problem_shape), params.dCacheK);
|
||||
auto gK = local_tile(mK(_, _, get<2>(blk_coord_cache)), TileShapeQK{}, make_coord(_, _, _0{}), Step<X, _1, _1>{});
|
||||
auto sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
|
||||
|
||||
auto tSgK = thr_mma_qk.partition_B(gK);
|
||||
auto tScK = thr_mma_qk.partition_B(cK);
|
||||
|
||||
auto tSlK = thr_mma_qk.partition_B(make_tensor((Element*) nullptr, make_ordered_layout(select<1,2>(TileShapeQK{}), Step<_1, _0>{})));
|
||||
auto tiled_copy_k = make_cotiled_copy(
|
||||
Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<uint128_t>, Element>{},
|
||||
atom_kv_tv,
|
||||
tSlK.layout());
|
||||
|
||||
auto thr_copy_k = tiled_copy_k.get_slice(thread_idx);
|
||||
|
||||
auto tKsK = thr_copy_k.partition_D(sK);
|
||||
auto tKgK = thr_copy_k.partition_S(tSgK);
|
||||
auto tKcK = thr_copy_k.partition_S(tScK);
|
||||
|
||||
int seqlen_cache_kv = get<1>(problem_shape) - ((params.ptr_new_k != nullptr) ? 1 : 0);
|
||||
auto limitK = append<2>(seqlen_cache_kv, _128{});
|
||||
|
||||
auto cV_t = make_identity_tensor(select<1,2>(TileShapePV{}));
|
||||
auto cV = make_tensor(cV_t.data(), make_layout(get<0>(cV_t.layout()), get<1>(cV_t.layout()), make_layout(_2{}, get<2>(TileShapePV{}) * stride<1>(cV_t))));
|
||||
auto mV = make_tensor(make_gmem_ptr(params.ptr_cache_v), select<2,1,3>(problem_shape), select<1,0,2>(params.dCacheV));
|
||||
auto gV = local_tile(mV(_, _, get<2>(blk_coord_cache)), TileShapePV{}, make_coord(_, _0{}, _), Step<X, _1, _1>{});
|
||||
auto sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
|
||||
|
||||
typename CollectiveMmaPV::TiledMma mma_pv;
|
||||
ThrMMA thr_mma_pv = mma_pv.get_slice(0);
|
||||
auto tOgV = thr_mma_pv.partition_B(gV);
|
||||
auto tOcV = thr_mma_pv.partition_B(cV);
|
||||
auto tOlV = thr_mma_pv.partition_B(make_tensor((Element*) nullptr, make_layout(select<1,2>(TileShapePV{}))));
|
||||
|
||||
auto tiled_copy_v = make_cotiled_copy(
|
||||
Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<uint128_t>, Element>{},
|
||||
atom_kv_tv,
|
||||
tOlV.layout());
|
||||
|
||||
auto thr_copy_v = tiled_copy_v.get_slice(thread_idx);
|
||||
|
||||
auto tVsV = thr_copy_v.partition_D(sV);
|
||||
auto tVgV = thr_copy_v.partition_S(tOgV);
|
||||
auto tVcV = thr_copy_v.partition_S(tOcV);
|
||||
|
||||
auto limitV = select<1,0>(limitK);
|
||||
|
||||
int full_tiles_cache = seqlen_cache_kv / get<1>(TileShapeQK{});
|
||||
|
||||
bool has_new = params.ptr_new_k != nullptr;
|
||||
Tensor mNewK = make_tensor(make_gmem_ptr(params.ptr_new_k), select<1,2,3>(problem_shape), params.dNewK);
|
||||
Tensor mNewV = make_tensor(make_gmem_ptr(params.ptr_new_v), select<1,2,3>(problem_shape), params.dNewV);
|
||||
Tensor gNewK = mNewK(_, _, get<2>(blk_coord));
|
||||
Tensor gNewV = mNewV(_, _, get<2>(blk_coord));
|
||||
|
||||
auto load_k = [&](int k_index, auto& state) {
|
||||
pipeline_kv.producer_acquire(state);
|
||||
|
||||
if (k_index < full_tiles_cache) {
|
||||
copy(tiled_copy_k, tKgK(_, _, _, _, k_index), tKsK(_, _, _, _, state.index()));
|
||||
pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive);
|
||||
} else {
|
||||
using Vec = uint128_t;
|
||||
Vec vzero = uint128_t(0, 0);
|
||||
auto src = recast<Vec>(tKgK(_, _, _, _, k_index));
|
||||
auto dst = recast<Vec>(tKsK(_, _, _, _, state.index()));
|
||||
auto src2 = recast<Vec>(gNewK);
|
||||
auto c = tKcK(_, _, _, _, k_index);
|
||||
int vlen = sizeof(Vec) / sizeof(Element);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(src); i++) {
|
||||
auto cc = c(vlen*i);
|
||||
Vec* dst_ptr = &dst(i);
|
||||
const Vec* src_ptr = &src(i);
|
||||
bool guard = elem_less(cc, limitK);
|
||||
if (get<0>(cc) == seqlen_cache_kv && has_new) {
|
||||
src_ptr = &src2(_0{}, get<1>(cc) / vlen);
|
||||
guard = true;
|
||||
}
|
||||
cutlass::arch::cp_async_zfill<16, cutlass::arch::CacheOperation::Global>(
|
||||
dst_ptr, src_ptr, guard
|
||||
);
|
||||
}
|
||||
|
||||
pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive);
|
||||
}
|
||||
};
|
||||
|
||||
auto load_v = [&](int v_index, auto& state) {
|
||||
pipeline_kv.producer_acquire(state);
|
||||
|
||||
if (v_index < full_tiles_cache) {
|
||||
copy(tiled_copy_v, tVgV(_, _, _, _, v_index), tVsV(_, _, _, _, state.index()));
|
||||
pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive);
|
||||
} else {
|
||||
using Vec = uint128_t;
|
||||
Vec vzero = uint128_t(0, 0);
|
||||
auto src = recast<Vec>(tVgV(_, _, _, _, v_index));
|
||||
auto dst = recast<Vec>(tVsV(_, _, _, _, state.index()));
|
||||
auto src2 = recast<Vec>(gNewV);
|
||||
int vlen = sizeof(Vec) / sizeof(Element);
|
||||
auto c = tVcV(_, _, _, _, v_index);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(src); i++) {
|
||||
auto cc = c(vlen*i);
|
||||
Vec* dst_ptr = &dst(i);
|
||||
const Vec* src_ptr = &src(i);
|
||||
bool guard = elem_less(cc, limitV);
|
||||
if (get<1>(cc) == seqlen_cache_kv && has_new) {
|
||||
src_ptr = &src2(_0{}, get<0>(cc) / vlen);
|
||||
guard = true;
|
||||
}
|
||||
cutlass::arch::cp_async_zfill<16, cutlass::arch::CacheOperation::Global>(
|
||||
dst_ptr, src_ptr, guard
|
||||
);
|
||||
}
|
||||
|
||||
pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive);
|
||||
}
|
||||
};
|
||||
|
||||
// K1
|
||||
int k_index = 0;
|
||||
int v_index = 0;
|
||||
|
||||
load_k(k_index, pipeline_kv_producer_state);
|
||||
|
||||
++pipeline_kv_producer_state;
|
||||
k_index += 1;
|
||||
|
||||
mask_tile_count -= 1;
|
||||
|
||||
for (; mask_tile_count > 0; mask_tile_count -= 1) {
|
||||
|
||||
load_k(k_index, pipeline_kv_producer_state);
|
||||
|
||||
++pipeline_kv_producer_state;
|
||||
k_index += 1;
|
||||
|
||||
load_v(v_index, pipeline_kv_producer_state);
|
||||
|
||||
++pipeline_kv_producer_state;
|
||||
v_index += 1;
|
||||
}
|
||||
|
||||
// V1
|
||||
|
||||
load_v(v_index, pipeline_kv_producer_state);
|
||||
|
||||
++pipeline_kv_producer_state;
|
||||
v_index += 1;
|
||||
|
||||
if (has_new) {
|
||||
for (int i = thread_idx; i < get<2>(TileShape{}); i += 64) {
|
||||
gK(seqlen_cache_kv, i, 0) = gNewK(0, i);
|
||||
gV(i, seqlen_cache_kv, 0) = gNewV(0, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
@ -0,0 +1,316 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/memory_sm80.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/layout.hpp"
|
||||
|
||||
#include "collective/fmha_common.hpp"
|
||||
#include "collective/fmha_fusion.hpp"
|
||||
|
||||
namespace cutlass::fmha::collective {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<
|
||||
class Element,
|
||||
class StrideQ,
|
||||
class StrideK,
|
||||
class StrideV,
|
||||
class CollectiveMmaQK,
|
||||
class CollectiveMmaPV,
|
||||
class SmemLayoutQ,
|
||||
class SmemLayoutK,
|
||||
class SmemLayoutV,
|
||||
class TensorStorage,
|
||||
class PipelineQ,
|
||||
class PipelineKV,
|
||||
class Mask,
|
||||
class TileShape
|
||||
>
|
||||
struct Sm100FmhaLoadTmaWarpspecialized {
|
||||
|
||||
using TileShapeQK = typename CollectiveMmaQK::TileShape;
|
||||
using TileShapePV = typename CollectiveMmaPV::TileShape;
|
||||
|
||||
struct Arguments {
|
||||
const Element* ptr_Q;
|
||||
StrideQ dQ;
|
||||
const Element* ptr_K;
|
||||
StrideK dK;
|
||||
const Element* ptr_V;
|
||||
StrideV dV;
|
||||
};
|
||||
|
||||
using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
|
||||
using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
|
||||
using TMA_V = typename CollectiveMmaPV::Params::TMA_B;
|
||||
|
||||
struct Params {
|
||||
TMA_Q tma_load_q;
|
||||
TMA_K tma_load_k;
|
||||
TMA_V tma_load_v;
|
||||
};
|
||||
|
||||
template<class ProblemShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape,
|
||||
Arguments const& args,
|
||||
void* workspace) {
|
||||
|
||||
auto ptr_Q = args.ptr_Q;
|
||||
auto ptr_K = args.ptr_K;
|
||||
auto ptr_V = args.ptr_V;
|
||||
auto dQ = args.dQ;
|
||||
auto dK = args.dK;
|
||||
auto dV = args.dV;
|
||||
auto problem_shape_qk = problem_shape;
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
|
||||
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
|
||||
if (cumulative_length_q != nullptr) {
|
||||
int max_length_q = get<0>(problem_shape).max_length;
|
||||
// for variable sequence lenght, the batch is in units of row_stride
|
||||
get<2,1>(dQ) = get<0>(dQ);
|
||||
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape)));
|
||||
// offset ptr by the amount we add back in later
|
||||
ptr_Q -= max_length_q * get<0>(dQ);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<1, ProblemShape>>) {
|
||||
auto cumulative_length_kv = get<1>(problem_shape).cumulative_length;
|
||||
if (cumulative_length_kv != nullptr) {
|
||||
int max_length_kv = get<1>(problem_shape).max_length;
|
||||
// for variable sequence lenght, the batch is in units of row_stride
|
||||
get<2,1>(dK) = get<0>(dK);
|
||||
get<2,1>(dV) = get<0>(dV);
|
||||
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape)));
|
||||
// offset ptr by the amount we add back in later
|
||||
ptr_K -= max_length_kv * get<0>(dK);
|
||||
ptr_V -= max_length_kv * get<0>(dV);
|
||||
}
|
||||
}
|
||||
|
||||
auto params_qk = CollectiveMmaQK::to_underlying_arguments(
|
||||
problem_shape_qk,
|
||||
typename CollectiveMmaQK::Arguments {
|
||||
ptr_Q, dQ,
|
||||
ptr_K, dK,
|
||||
}, /*workspace=*/ nullptr);
|
||||
|
||||
auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk);
|
||||
auto params_pv = CollectiveMmaPV::to_underlying_arguments(
|
||||
problem_shape_pv,
|
||||
typename CollectiveMmaPV::Arguments {
|
||||
ptr_K, dK, // never used, dummy
|
||||
ptr_V, select<1,0,2>(dV),
|
||||
}, /*workspace=*/ nullptr);
|
||||
|
||||
return Params{
|
||||
params_qk.tma_load_a,
|
||||
params_qk.tma_load_b,
|
||||
params_pv.tma_load_b
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& params) {
|
||||
cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
|
||||
}
|
||||
|
||||
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
|
||||
CUTLASS_DEVICE void
|
||||
load(
|
||||
BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,
|
||||
Params const& params, ParamsProblemShape const& params_problem_shape,
|
||||
TensorStorage& storage,
|
||||
PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state,
|
||||
PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) {
|
||||
|
||||
BlkCoord blk_coord_q = blk_coord_in;
|
||||
BlkCoord blk_coord_kv = blk_coord_in;
|
||||
|
||||
int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape);
|
||||
|
||||
using X = Underscore;
|
||||
|
||||
// this one is only executed by one thread, no need to elect_one
|
||||
|
||||
// Q1, K1, Q2, V1, K2, V2, K3, V3, ...
|
||||
// two pipes: Q and KV
|
||||
// from Memory (prod) to TensorCore (cons)
|
||||
|
||||
// compute gQ, sQ
|
||||
// we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1
|
||||
ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0);
|
||||
Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape));
|
||||
|
||||
int q_offs_0 = 0;
|
||||
int q_offs_2_1 = 0;
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
|
||||
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
|
||||
if (cumulative_length_q != nullptr) {
|
||||
int max_length_q = get<0>(params_problem_shape).max_length;
|
||||
q_offs_0 = max_length_q - get<0>(problem_shape);
|
||||
q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape);
|
||||
get<2,1>(blk_coord_q) = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p);
|
||||
|
||||
Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
|
||||
Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);
|
||||
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
|
||||
auto [tQgQ_qdl, tQsQ] = tma_partition(
|
||||
params.tma_load_q, _0{}, make_layout(_1{}),
|
||||
group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl)
|
||||
);
|
||||
Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q));
|
||||
|
||||
// compute gK, sK
|
||||
Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape));
|
||||
|
||||
int kv_offs_0 = 0;
|
||||
int kv_offs_2_1 = 0;
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) {
|
||||
auto cumulative_length = get<1>(params_problem_shape).cumulative_length;
|
||||
if (cumulative_length != nullptr) {
|
||||
int max_length = get<1>(params_problem_shape).max_length;
|
||||
kv_offs_0 = max_length - get<1>(problem_shape);
|
||||
kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape);
|
||||
get<2,1>(blk_coord_kv) = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p);
|
||||
|
||||
Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
|
||||
Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
|
||||
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
|
||||
auto [tKgK_kdl, tKsK] = tma_partition(
|
||||
params.tma_load_k, _0{}, make_layout(_1{}),
|
||||
group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl)
|
||||
);
|
||||
Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv));
|
||||
|
||||
// compute gV, sV
|
||||
ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
|
||||
Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape));
|
||||
|
||||
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p);
|
||||
|
||||
Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});
|
||||
Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);
|
||||
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
|
||||
auto [tVgV_dkl, tVsV] = tma_partition(
|
||||
params.tma_load_v, _0{}, make_layout(_1{}),
|
||||
group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl)
|
||||
);
|
||||
auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv));
|
||||
|
||||
// blk_coord in decomposed in terms of TileShape, not TileShapeQK
|
||||
// As such, it needs to be transformed as
|
||||
// (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1)
|
||||
// b -> 2*a (Ki i even) 2*a+1 (Ki i odd)
|
||||
|
||||
uint32_t lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Q1
|
||||
int q0_index = 2 * get<0>(blk_coord_q);
|
||||
int q1_index = 2 * get<0>(blk_coord_q) + 1;
|
||||
pipeline_q.producer_acquire(pipeline_q_producer_state);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
|
||||
copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index()));
|
||||
}
|
||||
++pipeline_q_producer_state;
|
||||
|
||||
// K1
|
||||
int k_index = 0;
|
||||
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
|
||||
copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index()));
|
||||
}
|
||||
++pipeline_kv_producer_state;
|
||||
|
||||
// Q2
|
||||
pipeline_q.producer_acquire(pipeline_q_producer_state);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
|
||||
copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index()));
|
||||
}
|
||||
++pipeline_q_producer_state;
|
||||
|
||||
// V1
|
||||
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
|
||||
copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index()));
|
||||
}
|
||||
++pipeline_kv_producer_state;
|
||||
k_index += 1;
|
||||
|
||||
// loop:
|
||||
mask_tile_count -= 1;
|
||||
for (; mask_tile_count > 0; mask_tile_count -= 1) {
|
||||
|
||||
// Ki
|
||||
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
|
||||
copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index()));
|
||||
}
|
||||
++pipeline_kv_producer_state;
|
||||
|
||||
// Vi
|
||||
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
|
||||
copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index()));
|
||||
}
|
||||
++pipeline_kv_producer_state;
|
||||
k_index += 1;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
92
examples/77_blackwell_fmha/common/pow_2.hpp
Normal file
92
examples/77_blackwell_fmha/common/pow_2.hpp
Normal file
@ -0,0 +1,92 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
#include <cute/numeric/integral_constant.hpp>
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace cutlass::fmha {
|
||||
|
||||
struct Pow2 {
|
||||
int n;
|
||||
int log2_n;
|
||||
|
||||
explicit CUTE_DEVICE Pow2(int n) : n(n) {
|
||||
#ifdef __CUDA_ARCH__
|
||||
log2_n = __ffs(n) - 1;
|
||||
#endif
|
||||
}
|
||||
|
||||
template<class T>
|
||||
CUTE_HOST_DEVICE T operator *(T const& b) const {
|
||||
return n * b;
|
||||
}
|
||||
|
||||
template<int N>
|
||||
CUTE_HOST_DEVICE auto operator *(Int<N> const&) const {
|
||||
if constexpr (N & (N - 1) == 0) {
|
||||
return Pow2{n * N};
|
||||
}
|
||||
return n * N;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template<class T>
|
||||
CUTE_HOST_DEVICE auto operator/(T const& a, Pow2 const& b) {
|
||||
return a >> b.log2_n;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
CUTE_HOST_DEVICE auto operator%(T const& a, Pow2 const& b) {
|
||||
return a & (b.n - 1);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
CUTE_HOST_DEVICE bool operator<(T const& a, Pow2 const& b) {
|
||||
return a < b.n;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE void print(Pow2 const& a) {
|
||||
printf("2^%d", a.log2_n);
|
||||
}
|
||||
|
||||
} // end namespace cutlass::fmha
|
||||
|
||||
namespace cute {
|
||||
|
||||
template <>
|
||||
struct is_integral<cutlass::fmha::Pow2> : true_type {};
|
||||
|
||||
} // end namespace cute
|
||||
276
examples/77_blackwell_fmha/device/fmha.hpp
Normal file
276
examples/77_blackwell_fmha/device/fmha.hpp
Normal file
@ -0,0 +1,276 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*!
|
||||
\file
|
||||
\brief An universal device layer for cutlass 3.x-style kernels.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// common
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include "cutlass/cluster_launch.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#endif // !defined(__CUDACC_RTC__)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::fmha::device {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class Kernel_>
|
||||
class FMHA {
|
||||
public:
|
||||
using Kernel = Kernel_;
|
||||
|
||||
static int const kThreadCount = Kernel::MaxThreadsPerBlock;
|
||||
|
||||
/// Argument structure: User API
|
||||
using Arguments = typename Kernel::Arguments;
|
||||
/// Argument structure: Kernel API
|
||||
using Params = typename Kernel::Params;
|
||||
|
||||
private:
|
||||
|
||||
/// Kernel API parameters object
|
||||
Params params_;
|
||||
|
||||
bool is_initialized(bool set = false) {
|
||||
static bool initialized = false;
|
||||
if (set) initialized = true;
|
||||
return initialized;
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Access the Params structure
|
||||
Params const& params() const {
|
||||
return params_;
|
||||
}
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status
|
||||
can_implement(Arguments const& args) {
|
||||
if (Kernel::can_implement(args)) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
else {
|
||||
return Status::kInvalid;
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
size_t workspace_bytes = 0;
|
||||
workspace_bytes += Kernel::get_workspace_size(args);
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Computes the grid shape
|
||||
static dim3
|
||||
get_grid_shape(Params const& params) {
|
||||
return Kernel::get_grid_shape(params);
|
||||
}
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
|
||||
CUTLASS_TRACE_HOST("FMHA::maximum_active_blocks()");
|
||||
int max_active_blocks = -1;
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
// first, account for dynamic smem capacity if needed
|
||||
cudaError_t result;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
result = cudaFuncSetAttribute(
|
||||
device_kernel<Kernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaFuncSetAttribute() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// query occupancy after setting smem size
|
||||
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks,
|
||||
device_kernel<Kernel>,
|
||||
Kernel::MaxThreadsPerBlock,
|
||||
smem_size);
|
||||
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status
|
||||
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("FMHA::initialize() - workspace "
|
||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
// Initialize the workspace
|
||||
Status status = Kernel::initialize_workspace(args, workspace, stream);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
// Initialize the Params structure
|
||||
params_ = Kernel::to_underlying_arguments(args, workspace);
|
||||
|
||||
if (is_initialized()) return Status::kSuccess;
|
||||
|
||||
// account for dynamic smem capacity if needed
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
cudaError_t result = cudaFuncSetAttribute(
|
||||
device_kernel<Kernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
is_initialized(true);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
|
||||
Status
|
||||
update(Arguments const& args, void* workspace = nullptr) {
|
||||
CUTLASS_TRACE_HOST("FMHA()::update() - workspace: " << workspace);
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
if (workspace_bytes > 0 && nullptr == workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
params_ = Kernel::to_underlying_arguments(args, workspace);
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Primary run() entry point API that is static allowing users to create and manage their own params.
|
||||
/// Supplied params struct must be construct by calling Kernel::to_underling_arguments()
|
||||
static Status
|
||||
run(Params& params, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("FMHA::run()");
|
||||
dim3 const block = Kernel::get_block_shape();
|
||||
dim3 const grid = get_grid_shape(params);
|
||||
|
||||
// configure smem size and carveout
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
Status launch_result;
|
||||
// Use extended launch API only for mainloops that use it
|
||||
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
|
||||
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
|
||||
cute::size<1>(typename Kernel::ClusterShape{}),
|
||||
cute::size<2>(typename Kernel::ClusterShape{}));
|
||||
void const* kernel = (void const*) device_kernel<Kernel>;
|
||||
void* kernel_params[] = {¶ms};
|
||||
launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
|
||||
}
|
||||
else {
|
||||
launch_result = Status::kSuccess;
|
||||
device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params);
|
||||
}
|
||||
|
||||
cudaError_t result = cudaGetLastError();
|
||||
if (cudaSuccess == result && Status::kSuccess == launch_result) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
else {
|
||||
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
|
||||
//
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
Status status = initialize(args, workspace, stream);
|
||||
if (Status::kSuccess == status) {
|
||||
status = run(params_, stream);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
return run(args, workspace, stream);
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
run(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
operator()(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::device
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
320
examples/77_blackwell_fmha/device/fmha_device_bwd.hpp
Normal file
320
examples/77_blackwell_fmha/device/fmha_device_bwd.hpp
Normal file
@ -0,0 +1,320 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
// common
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "../device/fmha.hpp"
|
||||
#include "../kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp"
|
||||
#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp"
|
||||
#include "../kernel/fmha_kernel_bwd_convert.hpp"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::fmha::device {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Element, class ElementAccumulator, class TileShape, class Mask>
|
||||
class Sm100FmhaBwd {
|
||||
public:
|
||||
/// Argument structure: User API
|
||||
struct Arguments {
|
||||
// Q K D HB
|
||||
cute::tuple<int, int, int, cute::tuple<int, int>> problem_size;
|
||||
|
||||
const Element* ptr_Q;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_Q;
|
||||
const Element* ptr_K;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_K;
|
||||
const Element* ptr_V;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_V;
|
||||
|
||||
const Element* ptr_O;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
|
||||
const ElementAccumulator* ptr_LSE;
|
||||
cute::tuple<cute::_1, cute::tuple<int, int>> stride_LSE;
|
||||
|
||||
const Element* ptr_dO;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dO;
|
||||
|
||||
Element* ptr_dQ;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dQ;
|
||||
Element* ptr_dK;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dK;
|
||||
Element* ptr_dV;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dV;
|
||||
|
||||
ElementAccumulator softmax_scale;
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
};
|
||||
|
||||
using OperationSumOdO = cutlass::fmha::device::FMHA<
|
||||
cutlass::fmha::kernel::FmhaKernelBwdSumOdO<Element, ElementAccumulator>
|
||||
>;
|
||||
using OperationConvert = cutlass::fmha::device::FMHA<
|
||||
cutlass::fmha::kernel::FmhaKernelBwdConvert<Element, ElementAccumulator>
|
||||
>;
|
||||
|
||||
using Operation = cutlass::fmha::device::FMHA<
|
||||
cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized<Element, ElementAccumulator, TileShape, Mask>
|
||||
>;
|
||||
using Kernel = typename Operation::Kernel;
|
||||
|
||||
struct Params {
|
||||
OperationSumOdO op_sum_OdO;
|
||||
Operation op;
|
||||
OperationConvert op_convert;
|
||||
ElementAccumulator* dQ_acc;
|
||||
size_t dQ_acc_size;
|
||||
};
|
||||
|
||||
private:
|
||||
Params params_;
|
||||
|
||||
static typename OperationSumOdO::Arguments to_sum_OdO_arguments(
|
||||
Arguments const& args,
|
||||
ElementAccumulator* sum_odo = nullptr,
|
||||
ElementAccumulator* scaled_lse = nullptr) {
|
||||
using namespace cute;
|
||||
auto [Q, K, D, HB] = args.problem_size;
|
||||
auto [H, B] = HB;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
Q = cutlass::round_up(Q, 8); // Alignment
|
||||
auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H));
|
||||
auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H));
|
||||
auto log2_e = log2f(expf(1.0f));
|
||||
return typename OperationSumOdO::Arguments {
|
||||
args.problem_size,
|
||||
args.ptr_O, args.stride_O,
|
||||
args.ptr_dO, args.stride_dO,
|
||||
sum_odo, stride_sum_OdO,
|
||||
args.ptr_LSE, args.stride_LSE,
|
||||
scaled_lse, stride_scaled_lse,
|
||||
-1.0f, -log2_e
|
||||
};
|
||||
}
|
||||
|
||||
static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) {
|
||||
using namespace cute;
|
||||
auto [Q, K, D, HB] = args.problem_size;
|
||||
auto [H, B] = HB;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
Q = cutlass::round_up(Q, 8); // Alignment
|
||||
auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H));
|
||||
return typename OperationConvert::Arguments {
|
||||
args.problem_size,
|
||||
src, stride_src_dQ,
|
||||
nullptr, stride_src_dQ,
|
||||
nullptr, stride_src_dQ,
|
||||
args.ptr_dQ, args.stride_dQ,
|
||||
nullptr, args.stride_dK,
|
||||
nullptr, args.stride_dV,
|
||||
args.softmax_scale
|
||||
};
|
||||
}
|
||||
|
||||
static typename Operation::Arguments to_bwd_arguments(
|
||||
Arguments const& args,
|
||||
ElementAccumulator* sum_OdO = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_sum_OdO = {},
|
||||
ElementAccumulator* scaled_lse = nullptr, cute::tuple<cute::_1, cute::tuple<int, int>> const& stride_scaled_lse = {},
|
||||
ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, cute::_1, cute::tuple<int, int>> const& stride_dQ = {}) {
|
||||
return typename Operation::Arguments{
|
||||
args.problem_size,
|
||||
{ args.ptr_Q, args.stride_Q,
|
||||
args.ptr_K, args.stride_K,
|
||||
args.ptr_V, args.stride_V,
|
||||
args.ptr_dO, args.stride_dO,
|
||||
scaled_lse, stride_scaled_lse,
|
||||
sum_OdO, stride_sum_OdO,
|
||||
dQ_acc, stride_dQ,
|
||||
args.softmax_scale },
|
||||
{ args.ptr_dK, args.stride_dK,
|
||||
args.ptr_dV, args.stride_dV },
|
||||
args.hw_info
|
||||
};
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status
|
||||
can_implement(Arguments const& args) {
|
||||
Status status = Status::kSuccess;
|
||||
|
||||
status = OperationSumOdO::can_implement(to_sum_OdO_arguments(args));
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = OperationConvert::can_implement(to_convert_arguments(args));
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = Operation::can_implement(to_bwd_arguments(args));
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
auto [Q, K, D, HB] = args.problem_size;
|
||||
auto [H, B] = HB;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
Q = cutlass::round_up(Q, 8); // Alignment
|
||||
size_t workspace_bytes = 0;
|
||||
// OdO vector
|
||||
workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
|
||||
// scaled LSE vector
|
||||
workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
|
||||
// FP32 versions of outputs that are churned (start off with Q only)
|
||||
workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator);
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Initializes state from arguments.
|
||||
Status
|
||||
initialize_split(Arguments const& args, void* workspace_dQ, void* workspace_sum_OdO, void* workspace_scaled_lse, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ="
|
||||
<< workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
auto [Q, K, D, HB] = args.problem_size;
|
||||
auto [H, B] = HB;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
Q = cutlass::round_up(Q, 8); // Alignment
|
||||
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_sum_OdO);
|
||||
ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_scaled_lse);
|
||||
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_dQ);
|
||||
params_.dQ_acc = dQ_acc;
|
||||
params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator);
|
||||
auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO, scaled_lse);
|
||||
auto args_convert = to_convert_arguments(args, dQ_acc);
|
||||
params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream);
|
||||
params_.op_convert.initialize(args_convert, nullptr, stream);
|
||||
auto args_bwd = to_bwd_arguments(
|
||||
args, sum_OdO, args_sum_OdO.stride_sum_OdO,
|
||||
scaled_lse, args_sum_OdO.stride_scaled_lse,
|
||||
dQ_acc, args_convert.stride_src_dQ
|
||||
);
|
||||
params_.op.initialize(args_bwd, nullptr, stream);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Initializes state from arguments.
|
||||
Status
|
||||
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("Universal::initialize() - workspace "
|
||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
auto [Q, K, D, HB] = args.problem_size;
|
||||
auto [H, B] = HB;
|
||||
D = cutlass::round_up(D, 8); // Alignment
|
||||
Q = cutlass::round_up(Q, 8); // Alignment
|
||||
char* workspace_chr = reinterpret_cast<char*>(workspace);
|
||||
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_chr);
|
||||
workspace_chr += B*H*Q * sizeof(ElementAccumulator);
|
||||
ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_chr);
|
||||
workspace_chr += B*H*Q * sizeof(ElementAccumulator);
|
||||
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_chr);
|
||||
return initialize_split(args, dQ_acc, sum_OdO, scaled_lse, stream);
|
||||
}
|
||||
|
||||
/// Primary run() entry point API that is static allowing users to create and manage their own params.
|
||||
/// Supplied params struct must be construct by calling Kernel::to_underling_arguments()
|
||||
static Status
|
||||
run(Params& params, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("FmhaDeviceBwd::run()");
|
||||
|
||||
Status result = Status::kSuccess;
|
||||
result = params.op_sum_OdO.run(stream);
|
||||
if (result != Status::kSuccess) {
|
||||
return result;
|
||||
}
|
||||
|
||||
auto cuda_result = cudaMemsetAsync(params.dQ_acc, 0, params.dQ_acc_size, stream);
|
||||
if (cuda_result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = params.op.run(stream);
|
||||
if (result != Status::kSuccess) {
|
||||
return result;
|
||||
}
|
||||
|
||||
result = params.op_convert.run(stream);
|
||||
if (result != Status::kSuccess) {
|
||||
return result;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
//
|
||||
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
|
||||
//
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
Status status = initialize(args, workspace, stream);
|
||||
if (Status::kSuccess == status) {
|
||||
status = run(params_, stream);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
run(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::fmha::device
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
357
examples/77_blackwell_fmha/device/sm100_mla.hpp
Normal file
357
examples/77_blackwell_fmha/device/sm100_mla.hpp
Normal file
@ -0,0 +1,357 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*!
|
||||
\file
|
||||
\brief An universal device layer for cutlass 3.x-style kernels.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// common
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include "cutlass/cluster_launch.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#endif // !defined(__CUDACC_RTC__)
|
||||
|
||||
#include "kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
|
||||
#include "kernel/sm100_fmha_mla_reduction.hpp"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::fmha::device {
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class Kernel_
|
||||
>
|
||||
class MLA {
|
||||
public:
|
||||
|
||||
using Kernel = Kernel_;
|
||||
|
||||
using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
|
||||
typename Kernel::ElementOut,
|
||||
typename Kernel::ElementAcc,
|
||||
typename Kernel::ElementAcc,
|
||||
Kernel::TileShapeH::value,
|
||||
Kernel::TileShapeL::value,
|
||||
256 /*Max split*/
|
||||
>;
|
||||
|
||||
/// Argument structure: User API
|
||||
using KernelArguments = typename Kernel::Arguments;
|
||||
using ReductionArguments = typename ReductionKernel::Arguments;
|
||||
|
||||
using Arguments = KernelArguments;
|
||||
|
||||
/// Argument structure: Kernel API
|
||||
using KernelParams = typename Kernel::Params;
|
||||
using ReductionParams = typename ReductionKernel::Params;
|
||||
struct Params {
|
||||
KernelParams fmha_params;
|
||||
ReductionParams reduction_params;
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
/// Kernel API parameters object
|
||||
Params params_;
|
||||
|
||||
bool is_initialized(bool set = false) {
|
||||
static bool initialized = false;
|
||||
if (set) initialized = true;
|
||||
return initialized;
|
||||
}
|
||||
|
||||
static ReductionArguments to_reduction_args(Arguments const& args) {
|
||||
auto [H, K, D, B] = args.problem_shape;
|
||||
return ReductionArguments{
|
||||
nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse,
|
||||
args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq,
|
||||
args.ptr_split_kv, Kernel::TileShapeS::value
|
||||
};
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Access the Params structure
|
||||
Params const& params() const {
|
||||
return params_;
|
||||
}
|
||||
|
||||
static void set_split_kv (KernelArguments& args) {
|
||||
if (args.split_kv >= 1) return;
|
||||
auto [H, K, D, B] = args.problem_shape;
|
||||
int sm_count = args.hw_info.sm_count;
|
||||
int max_splits = ceil_div(K, 128);
|
||||
int sms_per_batch = max(1, sm_count / B);
|
||||
int split_heur = min(max_splits, sms_per_batch);
|
||||
int waves = ceil_div(B * split_heur, sm_count);
|
||||
int k_waves = ceil_div(max_splits, split_heur);
|
||||
int split_wave_aware = ceil_div(max_splits, k_waves);
|
||||
args.split_kv = split_wave_aware;
|
||||
}
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status
|
||||
can_implement(Arguments const& args) {
|
||||
if (! Kernel::can_implement(args)) {
|
||||
return Status::kInvalid;
|
||||
}
|
||||
if (! ReductionKernel::can_implement(to_reduction_args(args))) {
|
||||
return Status::kInvalid;
|
||||
}
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
size_t workspace_bytes = 0;
|
||||
workspace_bytes += Kernel::get_workspace_size(args);
|
||||
workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args));
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
|
||||
CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()");
|
||||
int max_active_blocks = -1;
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
// first, account for dynamic smem capacity if needed
|
||||
cudaError_t result;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
result = cudaFuncSetAttribute(
|
||||
device_kernel<Kernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaFuncSetAttribute() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// query occupancy after setting smem size
|
||||
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks,
|
||||
device_kernel<Kernel>,
|
||||
Kernel::MaxThreadsPerBlock,
|
||||
smem_size);
|
||||
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status
|
||||
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA::initialize() - workspace "
|
||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
// Initialize the workspace
|
||||
Status status = Kernel::initialize_workspace(args, workspace, stream);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace);
|
||||
|
||||
ReductionArguments reduction_args = to_reduction_args(args);
|
||||
if (reduction_args.split_kv > 1) {
|
||||
reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc;
|
||||
reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc;
|
||||
}
|
||||
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
|
||||
// Initialize the Params structure
|
||||
params_ = Params {kernel_params, reduction_params};
|
||||
|
||||
if (is_initialized()) return Status::kSuccess;
|
||||
|
||||
// account for dynamic smem capacity if needed
|
||||
// no dynamic smem is needed for reduction kernel
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
cudaError_t result = cudaFuncSetAttribute(
|
||||
device_kernel<Kernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
is_initialized(true);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
|
||||
Status
|
||||
update(Arguments const& args, void* workspace = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace);
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
if (workspace_bytes > 0 && nullptr == workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
auto fmha_params = Kernel::to_underlying_arguments(args, workspace);
|
||||
|
||||
ReductionArguments reduction_args = to_reduction_args(args);
|
||||
if (reduction_args.split_kv > 1) {
|
||||
reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc;
|
||||
reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc;
|
||||
}
|
||||
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
|
||||
// Initialize the Params structure
|
||||
params_ = Params {fmha_params, reduction_params};
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Primary run() entry point API that is static allowing users to create and manage their own params.
|
||||
/// Supplied params struct must be construct by calling Kernel::to_underling_arguments()
|
||||
static Status
|
||||
run(Params& params, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA::run()");
|
||||
dim3 const block = Kernel::get_block_shape();
|
||||
dim3 const grid = Kernel::get_grid_shape(params.fmha_params);
|
||||
|
||||
// configure smem size and carveout
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
Status launch_result;
|
||||
// Use extended launch API only for mainloops that use it
|
||||
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
|
||||
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
|
||||
cute::size<1>(typename Kernel::ClusterShape{}),
|
||||
cute::size<2>(typename Kernel::ClusterShape{}));
|
||||
void const* kernel = (void const*) device_kernel<Kernel>;
|
||||
void* kernel_params[] = {¶ms.fmha_params};
|
||||
launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
|
||||
}
|
||||
else {
|
||||
launch_result = Status::kSuccess;
|
||||
device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
|
||||
}
|
||||
|
||||
cudaError_t result = cudaGetLastError();
|
||||
if (cudaSuccess != result or Status::kSuccess != launch_result) {
|
||||
//return Status::kSuccess;
|
||||
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
if (params.reduction_params.split_kv > 1) {
|
||||
// launch reduction kernel
|
||||
dim3 const block = ReductionKernel::get_block_shape();
|
||||
dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
|
||||
device_kernel<ReductionKernel><<<grid, block, 0, stream>>>(params.reduction_params);
|
||||
cudaError_t result = cudaGetLastError();
|
||||
if (cudaSuccess == result) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
else {
|
||||
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
|
||||
//
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
Status status = initialize(args, workspace, stream);
|
||||
if (Status::kSuccess == status) {
|
||||
status = run(params_, stream);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
return run(args, workspace, stream);
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
run(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
operator()(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::fmha::device
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
146
examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp
Normal file
146
examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp
Normal file
@ -0,0 +1,146 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/layout.hpp"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<class Element, class ElementAcc>
|
||||
struct FmhaKernelBwdConvert {
|
||||
|
||||
struct Arguments {
|
||||
tuple<int, int, int, tuple<int, int>> problem_size;
|
||||
|
||||
const ElementAcc* ptr_src_dQ;
|
||||
tuple<int, _1, tuple<int, int>> stride_src_dQ;
|
||||
const ElementAcc* ptr_src_dK;
|
||||
tuple<int, _1, tuple<int, int>> stride_src_dK;
|
||||
const ElementAcc* ptr_src_dV;
|
||||
tuple<int, _1, tuple<int, int>> stride_src_dV;
|
||||
|
||||
Element* ptr_dest_dQ;
|
||||
tuple<int, _1, tuple<int, int>> stride_dest_dQ;
|
||||
Element* ptr_dest_dK;
|
||||
tuple<int, _1, tuple<int, int>> stride_dest_dK;
|
||||
Element* ptr_dest_dV;
|
||||
tuple<int, _1, tuple<int, int>> stride_dest_dV;
|
||||
|
||||
ElementAcc scale = 1.0;
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
static constexpr int SharedStorageSize = 0;
|
||||
|
||||
static const int MinBlocksPerMultiprocessor = 1;
|
||||
static const int MaxThreadsPerBlock = 128;
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
|
||||
static const int kBlockSeq = 8;
|
||||
|
||||
static size_t get_workspace_size(Arguments const& args) { return 0; }
|
||||
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
static const int kNumThreadsD = 16;
|
||||
static const int kNumThreadsSeq = MaxThreadsPerBlock / kNumThreadsD;
|
||||
static const int kElementsPerLoad = 4;
|
||||
|
||||
static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq;
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
return get<2>(args.problem_size) % kElementsPerLoad == 0;
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
dim3 grid(size<3,0>(params.problem_size), size<3,1>(params.problem_size), ceil_div(std::max(size<0>(params.problem_size), size<1>(params.problem_size)), kBlockSeq));
|
||||
return grid;
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() {
|
||||
dim3 block(kNumThreadsD, kNumThreadsSeq, 1);
|
||||
return block;
|
||||
}
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template<class StrideSrc, class StrideDest>
|
||||
CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, int count) {
|
||||
auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y;
|
||||
auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y;
|
||||
|
||||
for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) {
|
||||
int idx_s = idx_s_t + kBlockSeq * blockIdx.z;
|
||||
if (idx_s >= count) continue;
|
||||
auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src);
|
||||
auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest);
|
||||
|
||||
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
|
||||
ElementAcc value_src[kElementsPerLoad];
|
||||
Element value_dest[kElementsPerLoad];
|
||||
|
||||
using VecSrc = uint_bit_t<sizeof_bits_v<ElementAcc> * kElementsPerLoad>;
|
||||
using VecDest = uint_bit_t<sizeof_bits_v<Element> * kElementsPerLoad>;
|
||||
*reinterpret_cast<VecSrc*>(value_src) = *reinterpret_cast<const VecSrc*>(&ptr_src_bhs[idx_d]);
|
||||
|
||||
for (int v = 0; v < kElementsPerLoad; v++) {
|
||||
value_dest[v] = static_cast<Element>(params.scale * value_src[v]);
|
||||
}
|
||||
|
||||
*reinterpret_cast<VecDest*>(&ptr_dest_bhs[idx_d]) = *reinterpret_cast<const VecDest*>(value_dest);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
||||
if (params.ptr_src_dQ != nullptr) {
|
||||
copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_size));
|
||||
}
|
||||
if (params.ptr_src_dK != nullptr) {
|
||||
copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_size));
|
||||
}
|
||||
if (params.ptr_src_dV != nullptr) {
|
||||
copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_size));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
151
examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp
Normal file
151
examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp
Normal file
@ -0,0 +1,151 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/layout.hpp"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<class Element, class ElementAcc>
|
||||
struct FmhaKernelBwdSumOdO {
|
||||
|
||||
struct Arguments {
|
||||
cute::tuple<int, int, int, cute::tuple<int, int>> problem_size;
|
||||
|
||||
const Element* ptr_O;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_O;
|
||||
const Element* ptr_dO;
|
||||
cute::tuple<int, cute::_1, cute::tuple<int, int>> stride_dO;
|
||||
|
||||
ElementAcc* ptr_sum_OdO;
|
||||
cute::tuple<cute::_1, cute::tuple<int, int>> stride_sum_OdO;
|
||||
|
||||
const ElementAcc* ptr_lse = nullptr;
|
||||
cute::tuple<cute::_1, cute::tuple<int, int>> stride_lse;
|
||||
|
||||
ElementAcc* ptr_scaled_lse = nullptr;
|
||||
cute::tuple<cute::_1, cute::tuple<int, int>> stride_scaled_lse;
|
||||
|
||||
ElementAcc sum_odo_scale = 1.0;
|
||||
ElementAcc lse_scale = 1.0;
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
static constexpr int SharedStorageSize = 0;
|
||||
|
||||
static const int MinBlocksPerMultiprocessor = 1;
|
||||
static const int MaxThreadsPerBlock = 128;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
|
||||
static size_t get_workspace_size(Arguments const& args) { return 0; }
|
||||
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
static const int kBlockQ = 16;
|
||||
|
||||
static const int kNumThreadsD = 8;
|
||||
static const int kNumThreadsQ = MaxThreadsPerBlock / kNumThreadsD;
|
||||
static const int kElementsPerLoad = 2;
|
||||
|
||||
static const int kIterationsQ = kBlockQ / kNumThreadsQ;
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
return get<2>(args.problem_size) % kElementsPerLoad == 0;
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
dim3 grid(ceil_div(size<0>(params.problem_size), kBlockQ), size<3,0>(params.problem_size), size<3,1>(params.problem_size));
|
||||
return grid;
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() {
|
||||
dim3 block(kNumThreadsD, kNumThreadsQ, 1);
|
||||
return block;
|
||||
}
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
||||
auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O);
|
||||
auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO);
|
||||
auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO);
|
||||
auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse);
|
||||
auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) {
|
||||
int idx_q = idx_q_t + kBlockQ * blockIdx.x;
|
||||
if (idx_q >= get<0>(params.problem_size)) continue;
|
||||
ElementAcc acc = 0;
|
||||
auto ptr_O_bhq = ptr_O_bh + idx_q * get<0>(params.stride_O);
|
||||
auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<0>(params.stride_dO);
|
||||
auto ptr_sum_OdO_bhq = ptr_sum_OdO_bh + idx_q * get<0>(params.stride_sum_OdO);
|
||||
auto ptr_lse_bhq = ptr_lse_bh + idx_q * get<0>(params.stride_lse);
|
||||
auto ptr_scaled_lse_bhq = ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse);
|
||||
|
||||
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<2>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
|
||||
Element value_O[kElementsPerLoad];
|
||||
Element value_dO[kElementsPerLoad];
|
||||
|
||||
using Vec = uint_bit_t<sizeof_bits_v<Element> * kElementsPerLoad>;
|
||||
*reinterpret_cast<Vec*>(value_O) = *reinterpret_cast<const Vec*>(&ptr_O_bhq[idx_d]);
|
||||
*reinterpret_cast<Vec*>(value_dO) = *reinterpret_cast<const Vec*>(&ptr_dO_bhq[idx_d]);
|
||||
|
||||
for (int v = 0; v < kElementsPerLoad; v++) {
|
||||
acc += value_O[v] * value_dO[v];
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 1; i < kNumThreadsD; i *= 2) {
|
||||
acc += __shfl_xor_sync((uint32_t)-1, acc, i, kNumThreadsD);
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
*ptr_sum_OdO_bhq = params.sum_odo_scale * acc;
|
||||
if (params.ptr_scaled_lse) {
|
||||
*ptr_scaled_lse_bhq = params.lse_scale * *ptr_lse_bhq;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
85
examples/77_blackwell_fmha/kernel/fmha_options.hpp
Normal file
85
examples/77_blackwell_fmha/kernel/fmha_options.hpp
Normal file
@ -0,0 +1,85 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
template<auto kTag, typename Default, typename... Options>
|
||||
struct find_option;
|
||||
|
||||
template<auto kTag, typename Default>
|
||||
struct find_option<kTag, Default> {
|
||||
using option_value = Default;
|
||||
};
|
||||
|
||||
template<auto kTag, typename Default, typename Option, typename... Options>
|
||||
struct find_option<kTag, Default, Option, Options...> :
|
||||
std::conditional_t<
|
||||
Option::tag == kTag,
|
||||
Option,
|
||||
find_option<kTag, Default, Options...>
|
||||
>
|
||||
{};
|
||||
|
||||
template<auto kTag, typename Default, typename... Options>
|
||||
using find_option_t = typename find_option<kTag, Default, Options...>::option_value;
|
||||
|
||||
enum class Tag {
|
||||
kIsPersistent,
|
||||
kNumMmaWarpGroups,
|
||||
kLoadsQSeparately,
|
||||
|
||||
kIsMainloopLocked,
|
||||
kIsEpilogueLocked,
|
||||
|
||||
kStagesQ,
|
||||
kStagesKV,
|
||||
|
||||
kEpilogueKind,
|
||||
|
||||
kBlocksPerSM,
|
||||
kClusterM,
|
||||
|
||||
kAccQK
|
||||
};
|
||||
|
||||
template<auto kTag, class Value>
|
||||
struct Option {
|
||||
static constexpr auto tag = kTag;
|
||||
using option_value = Value;
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
162
examples/77_blackwell_fmha/kernel/fmha_tile_scheduler.hpp
Normal file
162
examples/77_blackwell_fmha/kernel/fmha_tile_scheduler.hpp
Normal file
@ -0,0 +1,162 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct IndividualTileScheduler {
|
||||
|
||||
struct Params {
|
||||
dim3 grid;
|
||||
};
|
||||
|
||||
bool valid_ = true;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
IndividualTileScheduler(Params const&) {}
|
||||
|
||||
template<class ProblemSize, class ClusterShape, class TileShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, TileShape const& tile_shape) {
|
||||
using namespace cute;
|
||||
dim3 grid(round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,0>(problem_size), size<3,1>(problem_size));
|
||||
return Params{ grid };
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return params.grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return valid_;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
return make_coord(blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
IndividualTileScheduler& operator++() {
|
||||
valid_ = false;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct PersistentTileScheduler {
|
||||
|
||||
struct Params {
|
||||
int num_blocks;
|
||||
FastDivmod divmod_m_block;
|
||||
FastDivmod divmod_b;
|
||||
FastDivmod divmod_h;
|
||||
|
||||
KernelHardwareInfo hw_info;
|
||||
};
|
||||
|
||||
int block_idx = 0;
|
||||
Params params;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
PersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
|
||||
|
||||
template<class ProblemSize, class ClusterShape, class TileShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, TileShape const& tile_shape) {
|
||||
using namespace cute;
|
||||
// Get SM count if needed, otherwise use user supplied SM count
|
||||
int sm_count = hw_info.sm_count;
|
||||
if (sm_count <= 0) {
|
||||
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
||||
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
||||
hw_info.sm_count = sm_count;
|
||||
|
||||
int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape));
|
||||
int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size);
|
||||
|
||||
return Params {
|
||||
num_blocks,
|
||||
{ num_m_blocks}, { size<3,0>(problem_size) }, { size<3,1>(problem_size) },
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
|
||||
return grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return block_idx < params.num_blocks;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
int block_decode = block_idx;
|
||||
int m_block, bidb, bidh;
|
||||
params.divmod_m_block(block_decode, m_block, block_decode);
|
||||
params.divmod_b(block_decode, bidb, block_decode);
|
||||
params.divmod_h(block_decode, bidh, block_decode);
|
||||
return make_coord(m_block, _0{}, make_coord(bidb, bidh));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
PersistentTileScheduler& operator++() {
|
||||
block_idx += gridDim.x;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,519 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/layout.hpp"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cute/arch/tmem_allocator_sm100.hpp"
|
||||
|
||||
#include "kernel/fmha_options.hpp"
|
||||
#include "kernel/fmha_tile_scheduler.hpp"
|
||||
#include "collective/fmha_fusion.hpp"
|
||||
#include "collective/fmha_common.hpp"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::collective;
|
||||
|
||||
struct Sm100FmhaCtxKernelWarpspecializedSchedule {
|
||||
|
||||
enum class WarpRole {
|
||||
Softmax0,
|
||||
Softmax1,
|
||||
Correction,
|
||||
MMA,
|
||||
Load,
|
||||
Epilogue,
|
||||
Empty
|
||||
};
|
||||
|
||||
static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {
|
||||
int wg_idx = warp_idx / 4; // warp_idx
|
||||
if (wg_idx == 0) return WarpRole::Softmax0; // 0 - 3
|
||||
if (wg_idx == 1) return WarpRole::Softmax1; // 4 - 7
|
||||
if (wg_idx == 2) return WarpRole::Correction; // 8 - 11
|
||||
if (warp_idx == 12) return WarpRole::MMA; // 12
|
||||
if (warp_idx == 13) return WarpRole::Load; // 13
|
||||
if (warp_idx == 14) return WarpRole::Epilogue; // 14
|
||||
return WarpRole::Empty; // 15
|
||||
}
|
||||
|
||||
static const int NumWarpsSoftmax = 4;
|
||||
static const int NumWarpsCorrection = 4;
|
||||
static const int NumWarpsEpilogue = 1;
|
||||
static const int NumWarpsLoad = 1;
|
||||
|
||||
static const bool kDebugUsingPrintf = false;
|
||||
static const int NumRegsSoftmax = 192;
|
||||
static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0);
|
||||
static const int NumRegsOther = 32 + (kDebugUsingPrintf ? 16 : 0);
|
||||
static const int NumRegsEmpty = 24;
|
||||
|
||||
static const int NumWarps = 16;
|
||||
|
||||
};
|
||||
|
||||
template<
|
||||
class ProblemShapeIn,
|
||||
class CollectiveMainloop,
|
||||
class CollectiveEpilogue,
|
||||
class TileScheduler,
|
||||
class KernelSchedule = Sm100FmhaCtxKernelWarpspecializedSchedule
|
||||
>
|
||||
struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
using ProblemShape = ProblemShapeIn;
|
||||
|
||||
using WarpRole = typename KernelSchedule::WarpRole;
|
||||
|
||||
constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {
|
||||
return KernelSchedule::warp_idx_to_WarpRole(warp_idx);
|
||||
}
|
||||
|
||||
static const int NumWarpsSoftmax = KernelSchedule::NumWarpsSoftmax;
|
||||
static const int NumWarpsCorrection = KernelSchedule::NumWarpsCorrection;
|
||||
static const int NumWarpsEpilogue = KernelSchedule::NumWarpsEpilogue;
|
||||
static const int NumWarpsLoad = KernelSchedule::NumWarpsLoad;
|
||||
|
||||
static const int NumRegsSoftmax = KernelSchedule::NumRegsSoftmax;
|
||||
static const int NumRegsCorrection = KernelSchedule::NumRegsCorrection;
|
||||
static const int NumRegsOther = KernelSchedule::NumRegsOther;
|
||||
static const int NumRegsEmpty = 24;
|
||||
|
||||
static const int NumWarps = KernelSchedule::NumWarps;
|
||||
|
||||
using ClusterShape = typename CollectiveMainloop::ClusterShape;
|
||||
|
||||
using TmemAllocator = cute::TMEM::Allocator1Sm;
|
||||
|
||||
struct SharedStorage {
|
||||
typename CollectiveMainloop::TensorStorage mainloop;
|
||||
typename CollectiveEpilogue::TensorStorage epilogue;
|
||||
|
||||
struct PipelineStorage {
|
||||
alignas(16) typename CollectiveMainloop::PipelineQ::SharedStorage load_q;
|
||||
alignas(16) typename CollectiveMainloop::PipelineKV::SharedStorage load_kv;
|
||||
alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s0;
|
||||
alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s1;
|
||||
alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s0_corr;
|
||||
alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s1_corr;
|
||||
alignas(16) typename CollectiveMainloop::PipelineO::SharedStorage mma_corr;
|
||||
alignas(16) typename CollectiveMainloop::PipelineE::SharedStorage corr_epi;
|
||||
alignas(16) typename CollectiveMainloop::OrderBarrierSoftmax::SharedStorage order_s01;
|
||||
} pipelines;
|
||||
|
||||
uint32_t tmem_base_ptr;
|
||||
};
|
||||
|
||||
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
||||
|
||||
struct Arguments {
|
||||
ProblemShape problem_shape;
|
||||
typename CollectiveMainloop::Arguments mainloop;
|
||||
typename CollectiveEpilogue::Arguments epilogue;
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
};
|
||||
|
||||
struct Params {
|
||||
ProblemShape problem_shape;
|
||||
typename CollectiveMainloop::Params mainloop;
|
||||
typename CollectiveEpilogue::Params epilogue;
|
||||
typename TileScheduler::Params tile_scheduler;
|
||||
};
|
||||
|
||||
static const int MinBlocksPerMultiprocessor = 1;
|
||||
static const int MaxThreadsPerBlock = NumWarps * cutlass::NumThreadsPerWarp;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
|
||||
static size_t get_workspace_size(Arguments const& args) { return 0; }
|
||||
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
return CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return TileScheduler::get_grid_shape(params.tile_scheduler);
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() {
|
||||
dim3 block(MaxThreadsPerBlock, 1, 1);
|
||||
return block;
|
||||
}
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
return Params{
|
||||
args.problem_shape,
|
||||
CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
|
||||
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace),
|
||||
TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, TileShape{})
|
||||
};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE auto apply_batch(const Params ¶ms, ProblemShape const& problem_shape, int batch_idx) {
|
||||
return apply_variable_length(params.problem_shape, batch_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
||||
|
||||
TileScheduler tile_scheduler{params.tile_scheduler};
|
||||
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
auto role = warp_idx_to_WarpRole(warp_idx);
|
||||
uint32_t lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (role == WarpRole::Load && lane_predicate) {
|
||||
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
||||
}
|
||||
|
||||
if (role == WarpRole::Epilogue && lane_predicate) {
|
||||
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
|
||||
}
|
||||
|
||||
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem);
|
||||
|
||||
typename CollectiveMainloop::PipelineQ::Params pipeline_load_q_params;
|
||||
if (role == WarpRole::Load) {
|
||||
pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::MMA) {
|
||||
pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_load_q_params.is_leader = lane_predicate && (role == WarpRole::Load);
|
||||
pipeline_load_q_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadQ;
|
||||
typename CollectiveMainloop::PipelineQ pipeline_load_q(
|
||||
shared_storage.pipelines.load_q,
|
||||
pipeline_load_q_params,
|
||||
ClusterShape{}, cute::true_type{}, /*mask calc*/cute::false_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineKV::Params pipeline_load_kv_params;
|
||||
if (role == WarpRole::Load) {
|
||||
pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::MMA) {
|
||||
pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_load_kv_params.is_leader = lane_predicate && (role == WarpRole::Load);
|
||||
pipeline_load_kv_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadKV;
|
||||
typename CollectiveMainloop::PipelineKV pipeline_load_kv(
|
||||
shared_storage.pipelines.load_kv,
|
||||
pipeline_load_kv_params,
|
||||
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineS::Params pipeline_mma_s0_params;
|
||||
if (role == WarpRole::MMA) {
|
||||
pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Softmax0) {
|
||||
pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_mma_s0_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineS pipeline_mma_s0(
|
||||
shared_storage.pipelines.mma_s0,
|
||||
pipeline_mma_s0_params,
|
||||
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineS::Params pipeline_mma_s1_params;
|
||||
if (role == WarpRole::MMA) {
|
||||
pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Softmax1) {
|
||||
pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_mma_s1_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineS pipeline_mma_s1(
|
||||
shared_storage.pipelines.mma_s1,
|
||||
pipeline_mma_s1_params,
|
||||
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineC::Params pipeline_s0_corr_params;
|
||||
if (role == WarpRole::Softmax0) {
|
||||
pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Correction) {
|
||||
pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_s0_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
|
||||
pipeline_s0_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineC pipeline_s0_corr(
|
||||
shared_storage.pipelines.s0_corr,
|
||||
pipeline_s0_corr_params,
|
||||
/*barrier init*/ cute::true_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineC::Params pipeline_s1_corr_params;
|
||||
if (role == WarpRole::Softmax1) {
|
||||
pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Correction) {
|
||||
pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_s1_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
|
||||
pipeline_s1_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineC pipeline_s1_corr(
|
||||
shared_storage.pipelines.s1_corr,
|
||||
pipeline_s1_corr_params,
|
||||
/*barrier init*/ cute::true_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineO::Params pipeline_mma_corr_params;
|
||||
if (role == WarpRole::MMA) {
|
||||
pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Correction) {
|
||||
pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_mma_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineO pipeline_mma_corr(
|
||||
shared_storage.pipelines.mma_corr,
|
||||
pipeline_mma_corr_params,
|
||||
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineE::Params pipeline_corr_epi_params;
|
||||
if (role == WarpRole::Correction) {
|
||||
pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Epilogue) {
|
||||
pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
|
||||
pipeline_corr_epi_params.consumer_arv_count = NumWarpsEpilogue * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineE pipeline_corr_epi(
|
||||
shared_storage.pipelines.corr_epi,
|
||||
pipeline_corr_epi_params,
|
||||
/*barrier init*/ cute::true_type{});
|
||||
|
||||
typename CollectiveMainloop::OrderBarrierSoftmax::Params params_order_s01;
|
||||
params_order_s01.group_id = role == WarpRole::Softmax1 ? 1 : 0;
|
||||
params_order_s01.group_size = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::OrderBarrierSoftmax order_s01(
|
||||
shared_storage.pipelines.order_s01, params_order_s01);
|
||||
|
||||
TmemAllocator tmem_allocator;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
pipeline_load_q.init_masks(ClusterShape{});
|
||||
pipeline_load_kv.init_masks(ClusterShape{});
|
||||
pipeline_mma_s0.init_masks(ClusterShape{});
|
||||
pipeline_mma_s1.init_masks(ClusterShape{});
|
||||
pipeline_mma_corr.init_masks(ClusterShape{});
|
||||
|
||||
typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_consumer_state;
|
||||
typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineQ>();
|
||||
|
||||
typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_consumer_state;
|
||||
typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineKV>();
|
||||
|
||||
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_consumer_state;
|
||||
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineS>();
|
||||
|
||||
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_consumer_state;
|
||||
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineS>();
|
||||
|
||||
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_consumer_state;
|
||||
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineC>();
|
||||
|
||||
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_consumer_state;
|
||||
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineC>();
|
||||
|
||||
typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_consumer_state;
|
||||
typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineE>();
|
||||
|
||||
typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_consumer_state;
|
||||
typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineO>();
|
||||
|
||||
CollectiveMainloop mainloop;
|
||||
CollectiveEpilogue epilogue;
|
||||
|
||||
if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) {
|
||||
warpgroup_reg_set<NumRegsSoftmax>();
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
|
||||
auto logical_problem_shape = apply_batch(params,
|
||||
params.problem_shape, get<2,1>(blk_coord));
|
||||
|
||||
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool is_softmax_0 = role == WarpRole::Softmax0;
|
||||
|
||||
mainloop.softmax(
|
||||
is_softmax_0 ? 0 : 1, blk_coord,
|
||||
params.mainloop, logical_problem_shape,
|
||||
is_softmax_0 ? pipeline_mma_s0 : pipeline_mma_s1,
|
||||
is_softmax_0 ? pipeline_mma_s0_consumer_state : pipeline_mma_s1_consumer_state,
|
||||
is_softmax_0 ? pipeline_s0_corr : pipeline_s1_corr,
|
||||
is_softmax_0 ? pipeline_s0_corr_producer_state : pipeline_s1_corr_producer_state,
|
||||
order_s01
|
||||
);
|
||||
|
||||
}
|
||||
}
|
||||
else if (role == WarpRole::Correction) {
|
||||
cutlass::arch::warpgroup_reg_dealloc<NumRegsCorrection>();
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
|
||||
auto logical_problem_shape = apply_batch(params,
|
||||
params.problem_shape, get<2,1>(blk_coord));
|
||||
|
||||
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
mainloop.correction(
|
||||
blk_coord,
|
||||
params.mainloop, logical_problem_shape,
|
||||
shared_storage.epilogue,
|
||||
pipeline_s0_corr, pipeline_s0_corr_consumer_state,
|
||||
pipeline_s1_corr, pipeline_s1_corr_consumer_state,
|
||||
pipeline_mma_corr, pipeline_mma_corr_consumer_state,
|
||||
pipeline_corr_epi, pipeline_corr_epi_producer_state
|
||||
);
|
||||
|
||||
|
||||
}
|
||||
|
||||
if constexpr (NumWarpsEpilogue == 0) {
|
||||
static_assert(NumWarpsCorrection == 1);
|
||||
|
||||
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
|
||||
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
|
||||
}
|
||||
|
||||
}
|
||||
else if (role == WarpRole::MMA) {
|
||||
warpgroup_reg_set<NumRegsOther>();
|
||||
|
||||
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
|
||||
__syncwarp();
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
|
||||
auto logical_problem_shape = apply_batch(params,
|
||||
params.problem_shape, get<2,1>(blk_coord));
|
||||
|
||||
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
mainloop.mma(
|
||||
blk_coord,
|
||||
params.mainloop, logical_problem_shape,
|
||||
shared_storage.mainloop,
|
||||
pipeline_load_q, pipeline_load_q_consumer_state,
|
||||
pipeline_load_kv, pipeline_load_kv_consumer_state,
|
||||
pipeline_mma_s0, pipeline_mma_s0_producer_state,
|
||||
pipeline_mma_s1, pipeline_mma_s1_producer_state,
|
||||
pipeline_mma_corr, pipeline_mma_corr_producer_state
|
||||
);
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
else if (role == WarpRole::Load) {
|
||||
warpgroup_reg_set<NumRegsOther>();
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
|
||||
auto logical_problem_shape = apply_batch(params,
|
||||
params.problem_shape, get<2,1>(blk_coord));
|
||||
|
||||
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
mainloop.load(
|
||||
blk_coord, logical_problem_shape,
|
||||
params.mainloop, params.problem_shape,
|
||||
shared_storage.mainloop,
|
||||
pipeline_load_q, pipeline_load_q_producer_state,
|
||||
pipeline_load_kv, pipeline_load_kv_producer_state
|
||||
);
|
||||
|
||||
}
|
||||
}
|
||||
else if (role == WarpRole::Epilogue) {
|
||||
warpgroup_reg_set<NumRegsOther>();
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
|
||||
auto logical_problem_shape = apply_batch(params,
|
||||
params.problem_shape, get<2,1>(blk_coord));
|
||||
|
||||
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
epilogue.store(
|
||||
blk_coord, logical_problem_shape,
|
||||
params.epilogue, params.problem_shape,
|
||||
shared_storage.epilogue,
|
||||
pipeline_corr_epi, pipeline_corr_epi_consumer_state
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
static_assert(NumWarpsEpilogue <= 1);
|
||||
if constexpr (NumWarpsEpilogue == 1) {
|
||||
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
|
||||
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
|
||||
}
|
||||
|
||||
}
|
||||
else if (role == WarpRole::Empty) {
|
||||
warpgroup_reg_set<NumRegsEmpty>();
|
||||
|
||||
/* no-op, donate regs and exit */
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
@ -0,0 +1,576 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/layout.hpp"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cute/arch/tmem_allocator_sm100.hpp"
|
||||
|
||||
#include "kernel/fmha_options.hpp"
|
||||
#include "kernel/fmha_tile_scheduler.hpp"
|
||||
#include "collective/fmha_fusion.hpp"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::collective;
|
||||
|
||||
struct Sm100FmhaGenKernelWarpspecializedSchedule {
|
||||
|
||||
enum class WarpRole {
|
||||
Softmax0,
|
||||
Softmax1,
|
||||
Correction,
|
||||
MMA,
|
||||
Load,
|
||||
Epilogue,
|
||||
Empty
|
||||
};
|
||||
|
||||
static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {
|
||||
if (warp_idx == 0) return WarpRole::Softmax0; // 0 - 3
|
||||
if (warp_idx == 1) return WarpRole::MMA; // 12
|
||||
if (warp_idx == 2 || warp_idx == 3) return WarpRole::Load; // 13
|
||||
if (warp_idx == 4) return WarpRole::Softmax1; // 4 - 7
|
||||
if (warp_idx == 8) return WarpRole::Correction; // 8 - 11
|
||||
return WarpRole::Empty; // 15
|
||||
}
|
||||
|
||||
static const int NumWarpsSoftmax = 1;
|
||||
static const int NumWarpsCorrection = 1;
|
||||
static const int NumWarpsEpilogue = 0;
|
||||
static const int NumWarpsLoad = 2;
|
||||
|
||||
static const int NumRegsSoftmax = 192;
|
||||
static const int NumRegsCorrection = 104;
|
||||
static const int NumRegsOther = 248;
|
||||
static const int NumRegsEmpty = 24;
|
||||
|
||||
static const int NumWarps = 12;
|
||||
|
||||
};
|
||||
|
||||
template<
|
||||
class ProblemShapeIn,
|
||||
class CollectiveMainloop,
|
||||
class CollectiveEpilogue,
|
||||
class TileScheduler,
|
||||
class KernelSchedule = Sm100FmhaGenKernelWarpspecializedSchedule
|
||||
>
|
||||
struct Sm100FmhaGenKernelWarpspecialized {
|
||||
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
using ProblemShape = decltype(replace<0>(ProblemShapeIn{}, 0));
|
||||
|
||||
using WarpRole = typename KernelSchedule::WarpRole;
|
||||
|
||||
constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {
|
||||
return KernelSchedule::warp_idx_to_WarpRole(warp_idx);
|
||||
}
|
||||
|
||||
static const int NumWarpsSoftmax = KernelSchedule::NumWarpsSoftmax;
|
||||
static const int NumWarpsCorrection = KernelSchedule::NumWarpsCorrection;
|
||||
static const int NumWarpsEpilogue = KernelSchedule::NumWarpsEpilogue;
|
||||
static const int NumWarpsLoad = KernelSchedule::NumWarpsLoad;
|
||||
|
||||
static const int NumRegsSoftmax = KernelSchedule::NumRegsSoftmax;
|
||||
static const int NumRegsCorrection = KernelSchedule::NumRegsCorrection;
|
||||
static const int NumRegsOther = KernelSchedule::NumRegsOther;
|
||||
static const int NumRegsEmpty = 24;
|
||||
|
||||
static const int NumWarps = KernelSchedule::NumWarps;
|
||||
|
||||
using ClusterShape = typename CollectiveMainloop::ClusterShape;
|
||||
|
||||
using TmemAllocator = cute::TMEM::Allocator1Sm;
|
||||
|
||||
struct SharedStorage {
|
||||
typename CollectiveMainloop::TensorStorage mainloop;
|
||||
typename CollectiveEpilogue::TensorStorage epilogue;
|
||||
|
||||
struct PipelineStorage {
|
||||
alignas(16) typename CollectiveMainloop::PipelineQ::SharedStorage load_q;
|
||||
alignas(16) typename CollectiveMainloop::PipelineKV::SharedStorage load_kv;
|
||||
alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s0;
|
||||
alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s1;
|
||||
alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s0_corr;
|
||||
alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s1_corr;
|
||||
alignas(16) typename CollectiveMainloop::PipelineO::SharedStorage mma_corr;
|
||||
alignas(16) typename CollectiveMainloop::PipelineE::SharedStorage corr_epi;
|
||||
alignas(16) typename CollectiveMainloop::OrderBarrierSoftmax::SharedStorage order_s01;
|
||||
} pipelines;
|
||||
|
||||
uint32_t tmem_base_ptr;
|
||||
};
|
||||
|
||||
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
||||
|
||||
using StrideQOrig = typename CollectiveMainloop::StrideQOrig;
|
||||
using StrideOOrig = typename CollectiveMainloop::StrideOOrig;
|
||||
using StrideQ = typename CollectiveMainloop::StrideQ;
|
||||
using StrideO = typename CollectiveMainloop::StrideO;
|
||||
using StrideCacheK = typename CollectiveMainloop::StrideCacheK;
|
||||
using StrideCacheV = typename CollectiveMainloop::StrideCacheV;
|
||||
using StrideNewK = typename CollectiveMainloop::StrideNewK;
|
||||
using StrideNewV = typename CollectiveMainloop::StrideNewV;
|
||||
using Element = typename CollectiveMainloop::Element;
|
||||
using ElementAcc = typename CollectiveMainloop::ElementAcc;
|
||||
using ElementOut = typename CollectiveMainloop::ElementOut;
|
||||
|
||||
struct Arguments {
|
||||
// _1, max_seqlen_k, head_dim, ((h_g, h_kv), b)
|
||||
ProblemShapeIn problem_shape;
|
||||
const int* seqlen_kv;
|
||||
const int* cache_batch_idx;
|
||||
|
||||
const Element* ptr_q; // 1 x D x (H x B)
|
||||
StrideQOrig dQ;
|
||||
const Element* ptr_new_k; // 1 x D x (H x B)
|
||||
StrideNewK dNewK;
|
||||
const Element* ptr_new_v; // 1 x D x (H x B)
|
||||
StrideNewV dNewV;
|
||||
|
||||
Element* ptr_cache_k; // seqlen_max x D x (H x B)
|
||||
StrideCacheK dCacheK;
|
||||
Element* ptr_cache_v; // seqlen_max x D x (H x B)
|
||||
StrideCacheV dCacheV;
|
||||
ElementOut* ptr_o; // 1 x D x (H x B)
|
||||
StrideOOrig dO;
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
|
||||
ElementAcc scale_softmax = 0.0f;
|
||||
};
|
||||
|
||||
struct Params {
|
||||
ProblemShape problem_shape;
|
||||
const int* seqlen_kv;
|
||||
typename CollectiveMainloop::Params mainloop;
|
||||
typename CollectiveEpilogue::Params epilogue;
|
||||
typename TileScheduler::Params tile_scheduler;
|
||||
};
|
||||
|
||||
static const int MinBlocksPerMultiprocessor = 1;
|
||||
static const int MaxThreadsPerBlock = NumWarps * cutlass::NumThreadsPerWarp;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
|
||||
static size_t get_workspace_size(Arguments const& args) { return 0; }
|
||||
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
return true;
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return TileScheduler::get_grid_shape(params.tile_scheduler);
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() {
|
||||
dim3 block(MaxThreadsPerBlock, 1, 1);
|
||||
return block;
|
||||
}
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
ProblemShape problem_shape = replace<0>(args.problem_shape, static_cast<int>(get<0>(args.problem_shape)));
|
||||
CUTE_STATIC_ASSERT_V(get<0>(args.problem_shape) == _1{});
|
||||
StrideQ dQ = replace<0>(args.dQ, 0);
|
||||
StrideO dO = replace<0>(args.dO, 0);
|
||||
get<0>(problem_shape) = get<3,0,0>(args.problem_shape);
|
||||
get<3,0,0>(problem_shape) = 1;
|
||||
get<0>(dQ) = get<2,0,0>(dQ);
|
||||
get<0>(dO) = get<2,0,0>(dO);
|
||||
|
||||
typename CollectiveMainloop::Arguments mainloop_args {
|
||||
{
|
||||
args.cache_batch_idx,
|
||||
args.ptr_q, dQ,
|
||||
args.ptr_new_k, args.dNewK,
|
||||
args.ptr_new_v, args.dNewV,
|
||||
args.ptr_cache_k, args.dCacheK,
|
||||
args.ptr_cache_v, args.dCacheV,
|
||||
},
|
||||
args.scale_softmax
|
||||
};
|
||||
|
||||
typename CollectiveEpilogue::Arguments epilogue_args {
|
||||
args.ptr_o, dO,
|
||||
};
|
||||
|
||||
return Params{
|
||||
problem_shape,
|
||||
args.seqlen_kv,
|
||||
CollectiveMainloop::to_underlying_arguments(problem_shape, mainloop_args, workspace),
|
||||
CollectiveEpilogue::to_underlying_arguments(problem_shape, epilogue_args, workspace),
|
||||
TileScheduler::to_underlying_arguments(problem_shape, args.hw_info, ClusterShape{}, TileShape{})
|
||||
};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE auto apply_batch(const Params ¶ms, ProblemShape const& problem_shape, int batch_idx) {
|
||||
ProblemShape result = problem_shape;
|
||||
get<1>(result) = params.seqlen_kv[batch_idx];
|
||||
if (params.mainloop.load.ptr_new_k != nullptr) {
|
||||
get<1>(result) += 1;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
||||
|
||||
TileScheduler tile_scheduler{params.tile_scheduler};
|
||||
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
auto role = warp_idx_to_WarpRole(warp_idx);
|
||||
uint32_t lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (role == WarpRole::Load && lane_predicate) {
|
||||
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
||||
}
|
||||
|
||||
if (role == WarpRole::Epilogue && lane_predicate) {
|
||||
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
|
||||
}
|
||||
|
||||
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem);
|
||||
|
||||
typename CollectiveMainloop::PipelineQ::Params pipeline_load_q_params;
|
||||
if (role == WarpRole::Load) {
|
||||
pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::MMA) {
|
||||
pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_load_q_params.producer_arv_count = NumWarpsLoad * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineQ pipeline_load_q(
|
||||
shared_storage.pipelines.load_q,
|
||||
pipeline_load_q_params,
|
||||
ClusterShape{}, cute::true_type{}, /*mask calc*/cute::false_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineKV::Params pipeline_load_kv_params;
|
||||
if (role == WarpRole::Load) {
|
||||
pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::MMA) {
|
||||
pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_load_kv_params.producer_arv_count = NumWarpsLoad * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineKV pipeline_load_kv(
|
||||
shared_storage.pipelines.load_kv,
|
||||
pipeline_load_kv_params,
|
||||
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineS::Params pipeline_mma_s0_params;
|
||||
if (role == WarpRole::MMA) {
|
||||
pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Softmax0) {
|
||||
pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_mma_s0_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineS pipeline_mma_s0(
|
||||
shared_storage.pipelines.mma_s0,
|
||||
pipeline_mma_s0_params,
|
||||
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineS::Params pipeline_mma_s1_params;
|
||||
if (role == WarpRole::MMA) {
|
||||
pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Softmax1) {
|
||||
pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_mma_s1_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineS pipeline_mma_s1(
|
||||
shared_storage.pipelines.mma_s1,
|
||||
pipeline_mma_s1_params,
|
||||
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineC::Params pipeline_s0_corr_params;
|
||||
if (role == WarpRole::Softmax0) {
|
||||
pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Correction) {
|
||||
pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_s0_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
|
||||
pipeline_s0_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineC pipeline_s0_corr(
|
||||
shared_storage.pipelines.s0_corr,
|
||||
pipeline_s0_corr_params,
|
||||
/*barrier init*/ cute::true_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineC::Params pipeline_s1_corr_params;
|
||||
if (role == WarpRole::Softmax1) {
|
||||
pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Correction) {
|
||||
pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_s1_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
|
||||
pipeline_s1_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineC pipeline_s1_corr(
|
||||
shared_storage.pipelines.s1_corr,
|
||||
pipeline_s1_corr_params,
|
||||
/*barrier init*/ cute::true_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineO::Params pipeline_mma_corr_params;
|
||||
if (role == WarpRole::MMA) {
|
||||
pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Correction) {
|
||||
pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_mma_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineO pipeline_mma_corr(
|
||||
shared_storage.pipelines.mma_corr,
|
||||
pipeline_mma_corr_params,
|
||||
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineE::Params pipeline_corr_epi_params;
|
||||
if (role == WarpRole::Correction) {
|
||||
pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Epilogue) {
|
||||
pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
|
||||
pipeline_corr_epi_params.consumer_arv_count = NumWarpsEpilogue * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineE pipeline_corr_epi(
|
||||
shared_storage.pipelines.corr_epi,
|
||||
pipeline_corr_epi_params,
|
||||
/*barrier init*/ cute::true_type{});
|
||||
|
||||
typename CollectiveMainloop::OrderBarrierSoftmax::Params params_order_s01;
|
||||
params_order_s01.group_id = role == WarpRole::Softmax1 ? 1 : 0;
|
||||
params_order_s01.group_size = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::OrderBarrierSoftmax order_s01(
|
||||
shared_storage.pipelines.order_s01, params_order_s01);
|
||||
|
||||
TmemAllocator tmem_allocator;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
pipeline_load_q.init_masks(ClusterShape{});
|
||||
pipeline_load_kv.init_masks(ClusterShape{});
|
||||
pipeline_mma_s0.init_masks(ClusterShape{});
|
||||
pipeline_mma_s1.init_masks(ClusterShape{});
|
||||
pipeline_mma_corr.init_masks(ClusterShape{});
|
||||
|
||||
typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_consumer_state;
|
||||
typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineQ>();
|
||||
|
||||
typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_consumer_state;
|
||||
typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineKV>();
|
||||
|
||||
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_consumer_state;
|
||||
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineS>();
|
||||
|
||||
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_consumer_state;
|
||||
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineS>();
|
||||
|
||||
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_consumer_state;
|
||||
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineC>();
|
||||
|
||||
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_consumer_state;
|
||||
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineC>();
|
||||
|
||||
typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_consumer_state;
|
||||
typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineE>();
|
||||
|
||||
typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_consumer_state;
|
||||
typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineO>();
|
||||
|
||||
CollectiveMainloop mainloop;
|
||||
CollectiveEpilogue epilogue(params.epilogue);
|
||||
|
||||
if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) {
|
||||
warpgroup_reg_set<NumRegsSoftmax>();
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
|
||||
auto logical_problem_shape = apply_batch(params,
|
||||
params.problem_shape, get<2,1>(blk_coord));
|
||||
|
||||
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool is_softmax_0 = role == WarpRole::Softmax0;
|
||||
|
||||
mainloop.softmax(
|
||||
is_softmax_0 ? 0 : 1, blk_coord,
|
||||
params.mainloop, logical_problem_shape,
|
||||
is_softmax_0 ? pipeline_mma_s0 : pipeline_mma_s1,
|
||||
is_softmax_0 ? pipeline_mma_s0_consumer_state : pipeline_mma_s1_consumer_state,
|
||||
is_softmax_0 ? pipeline_s0_corr : pipeline_s1_corr,
|
||||
is_softmax_0 ? pipeline_s0_corr_producer_state : pipeline_s1_corr_producer_state,
|
||||
order_s01
|
||||
);
|
||||
|
||||
}
|
||||
}
|
||||
else if (role == WarpRole::Correction) {
|
||||
cutlass::arch::warpgroup_reg_dealloc<NumRegsCorrection>();
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
|
||||
auto logical_problem_shape = apply_batch(params,
|
||||
params.problem_shape, get<2,1>(blk_coord));
|
||||
|
||||
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
mainloop.correction(
|
||||
blk_coord,
|
||||
params.mainloop, logical_problem_shape,
|
||||
shared_storage.epilogue,
|
||||
pipeline_s0_corr, pipeline_s0_corr_consumer_state,
|
||||
pipeline_s1_corr, pipeline_s1_corr_consumer_state,
|
||||
pipeline_mma_corr, pipeline_mma_corr_consumer_state,
|
||||
pipeline_corr_epi, pipeline_corr_epi_producer_state,
|
||||
epilogue
|
||||
);
|
||||
|
||||
|
||||
}
|
||||
|
||||
if constexpr (NumWarpsEpilogue == 0) {
|
||||
static_assert(NumWarpsCorrection == 1);
|
||||
|
||||
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
|
||||
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
|
||||
}
|
||||
|
||||
}
|
||||
else if (role == WarpRole::MMA) {
|
||||
warpgroup_reg_set<NumRegsOther>();
|
||||
|
||||
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
|
||||
__syncwarp();
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
|
||||
auto logical_problem_shape = apply_batch(params,
|
||||
params.problem_shape, get<2,1>(blk_coord));
|
||||
|
||||
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
mainloop.mma(
|
||||
blk_coord,
|
||||
params.mainloop, logical_problem_shape,
|
||||
shared_storage.mainloop,
|
||||
pipeline_load_q, pipeline_load_q_consumer_state,
|
||||
pipeline_load_kv, pipeline_load_kv_consumer_state,
|
||||
pipeline_mma_s0, pipeline_mma_s0_producer_state,
|
||||
pipeline_mma_s1, pipeline_mma_s1_producer_state,
|
||||
pipeline_mma_corr, pipeline_mma_corr_producer_state
|
||||
);
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
else if (role == WarpRole::Load) {
|
||||
warpgroup_reg_set<NumRegsOther>();
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
|
||||
auto logical_problem_shape = apply_batch(params,
|
||||
params.problem_shape, get<2,1>(blk_coord));
|
||||
|
||||
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
mainloop.load(
|
||||
blk_coord, logical_problem_shape,
|
||||
params.mainloop, params.problem_shape,
|
||||
shared_storage.mainloop,
|
||||
pipeline_load_q, pipeline_load_q_producer_state,
|
||||
pipeline_load_kv, pipeline_load_kv_producer_state
|
||||
);
|
||||
|
||||
}
|
||||
}
|
||||
else if (role == WarpRole::Epilogue) {
|
||||
warpgroup_reg_set<NumRegsOther>();
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
|
||||
auto logical_problem_shape = apply_batch(params,
|
||||
params.problem_shape, get<2,1>(blk_coord));
|
||||
|
||||
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
epilogue.store(
|
||||
blk_coord, logical_problem_shape,
|
||||
params.epilogue, params.problem_shape,
|
||||
shared_storage.epilogue,
|
||||
pipeline_corr_epi, pipeline_corr_epi_consumer_state
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
static_assert(NumWarpsEpilogue <= 1);
|
||||
if constexpr (NumWarpsEpilogue == 1) {
|
||||
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
|
||||
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
|
||||
}
|
||||
|
||||
}
|
||||
else if (role == WarpRole::Empty) {
|
||||
warpgroup_reg_set<NumRegsEmpty>();
|
||||
|
||||
/* no-op, donate regs and exit */
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
197
examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp
Normal file
197
examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp
Normal file
@ -0,0 +1,197 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
using namespace cute;
|
||||
template<
|
||||
class ElementOut,
|
||||
class ElementAcc,
|
||||
class ElementScale,
|
||||
size_t kNumHeads,
|
||||
size_t kHeadDimLatent,
|
||||
int kMaxSplits
|
||||
>
|
||||
struct Sm100FmhaMlaReductionKernel {
|
||||
|
||||
static const int SharedStorageSize = 0;
|
||||
static const int MaxThreadsPerBlock = 128;
|
||||
static const int MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
|
||||
static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0);
|
||||
struct Arguments {
|
||||
ElementAcc* ptr_oaccum = nullptr;
|
||||
ElementOut* ptr_o = nullptr;
|
||||
ElementAcc* ptr_lseaccum = nullptr;
|
||||
ElementAcc* ptr_lse = nullptr;
|
||||
ElementScale scale = 1.f;
|
||||
int num_batches = 0;
|
||||
int split_kv = -1;
|
||||
int dim_k = -1;
|
||||
int* ptr_seq = nullptr;
|
||||
int* ptr_split_kv = nullptr;
|
||||
int tile_shape_s = 128;
|
||||
};
|
||||
using Params = Arguments;
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse,
|
||||
args.scale, args.num_batches, args.split_kv, args.dim_k, args.ptr_seq,
|
||||
args.ptr_split_kv, args.tile_shape_s};
|
||||
}
|
||||
|
||||
static size_t get_workspace_size(Arguments const& /*args*/) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static Status initialize_workspace(
|
||||
Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return dim3(kNumHeads, 1, params.num_batches);
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() {
|
||||
return dim3(MaxThreadsPerBlock, 1, 1);
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
if (args.num_batches <= 0) return false;
|
||||
if (args.split_kv <= 0) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) {
|
||||
if (params.split_kv <= 1) return;
|
||||
auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z);
|
||||
|
||||
__shared__ ElementAcc sLseScale[kMaxSplits];
|
||||
const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord);
|
||||
const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord);
|
||||
|
||||
Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum),
|
||||
make_shape(params.split_kv), Stride<Int<kNumHeads>>{});
|
||||
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse),
|
||||
Shape<_1>{}, Stride<_1>{});
|
||||
|
||||
auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
|
||||
auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
|
||||
auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
|
||||
auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
|
||||
local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
|
||||
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
if (warp_idx == 0) {
|
||||
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
|
||||
|
||||
ElementAcc local_lse[kNLsePerThread];
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
const int split = i * 32 + threadIdx.x;
|
||||
local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits<ElementAcc>::infinity();
|
||||
}
|
||||
|
||||
ElementAcc lse_max = -std::numeric_limits<ElementAcc>::infinity();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
lse_max = max(lse_max, local_lse[i]);
|
||||
}
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int offset = 16; offset >= 1; offset /= 2) {
|
||||
lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset));
|
||||
}
|
||||
lse_max = lse_max == -std::numeric_limits<ElementAcc>::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf
|
||||
lse_max = __shfl_sync(0xffffffff, lse_max, 0);
|
||||
|
||||
ElementAcc sum_lse = 0;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
sum_lse = sum_lse + expf(local_lse[i] - params.scale * lse_max);
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int offset = 16; offset >= 1; offset /= 2) {
|
||||
sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset);
|
||||
}
|
||||
|
||||
sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);
|
||||
|
||||
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + params.scale * lse_max;
|
||||
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
|
||||
gLSE(0) = global_lse;
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
const int split = i * 32 + threadIdx.x;
|
||||
if (split < local_split_kv) {
|
||||
sLseScale[split] = expf(local_lse[i] - global_lse);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock;
|
||||
const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord));
|
||||
Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum),
|
||||
Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
|
||||
ElementAcc local_val[Elements] = {0};
|
||||
for (int split = 0; split < local_split_kv; ++split) {
|
||||
ElementAcc lse_scale = sLseScale[split];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int i = 0; i < Elements; ++i) {
|
||||
local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i);
|
||||
}
|
||||
gOaccum.data() = gOaccum.data() + kHeadDimLatent;
|
||||
}
|
||||
auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent;
|
||||
Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int i = 0; i < Elements; ++i) {
|
||||
gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast<ElementOut>(local_val[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
File diff suppressed because it is too large
Load Diff
160
examples/77_blackwell_fmha/kernel/sm100_mla_tile_scheduler.hpp
Normal file
160
examples/77_blackwell_fmha/kernel/sm100_mla_tile_scheduler.hpp
Normal file
@ -0,0 +1,160 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Sm100MlaIndividualTileScheduler {
|
||||
|
||||
struct Params {
|
||||
dim3 grid;
|
||||
};
|
||||
|
||||
bool valid_ = true;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaIndividualTileScheduler(Params const&) {}
|
||||
|
||||
template<class ProblemShape, class ClusterShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, int const& split_kv) {
|
||||
using namespace cute;
|
||||
dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, split_kv /*Maximum Split KV*/);
|
||||
return Params{ grid };
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return params.grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return valid_;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaIndividualTileScheduler& operator++() {
|
||||
valid_ = false;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Sm100MlaPersistentTileScheduler {
|
||||
|
||||
struct Params {
|
||||
int num_blocks;
|
||||
FastDivmod divmod_m_block;
|
||||
FastDivmod divmod_b;
|
||||
FastDivmod divmod_split_kv;
|
||||
KernelHardwareInfo hw_info;
|
||||
};
|
||||
|
||||
int block_idx = 0;
|
||||
Params params;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
|
||||
|
||||
template<class ProblemShape, class ClusterShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, int const& split_kv) {
|
||||
using namespace cute;
|
||||
// Get SM count if needed, otherwise use user supplied SM count
|
||||
int sm_count = hw_info.sm_count;
|
||||
if (sm_count <= 1 || sm_count % size<0>(cluster_shape) != 0) {
|
||||
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
||||
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
||||
hw_info.sm_count = sm_count;
|
||||
|
||||
int num_m_blocks = size<0>(cluster_shape);
|
||||
int num_blocks = num_m_blocks * get<3>(problem_shape) /* Batch */;
|
||||
num_blocks *= split_kv; /* Maximum Split KV*/
|
||||
|
||||
return Params {
|
||||
num_blocks,
|
||||
{ num_m_blocks}, { get<3>(problem_shape) }, {split_kv},
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
|
||||
return grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return block_idx < params.num_blocks;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
int block_decode = block_idx;
|
||||
int m_block, bidb, n_split_kv;
|
||||
params.divmod_m_block(block_decode, m_block, block_decode);
|
||||
params.divmod_b(block_decode, bidb, block_decode);
|
||||
params.divmod_split_kv(block_decode, n_split_kv, block_decode);
|
||||
return make_coord(m_block, _0{}, bidb, n_split_kv);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaPersistentTileScheduler& operator++() {
|
||||
block_idx += gridDim.x;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user