CUTLASS 3.8 Release (#2059)
* CUTLASS 3.8 Release
* update
* Update README.md
* Revert "Update README.md"
This reverts commit b353e36fe8.
* update
* update
---------
Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
73
ACTIVE_DEVELOPERS.md
Normal file
73
ACTIVE_DEVELOPERS.md
Normal file
@ -0,0 +1,73 @@
|
||||

|
||||
|
||||
[README](./README.md#documentation) > **Active Developers**
|
||||
|
||||
# CUTLASS Developers **
|
||||
|
||||
Andrew Kerr (CUTLASS founding member)<br />
|
||||
Dustyn Blasig<br />
|
||||
Albert Xu<br />
|
||||
Junkai Wu<br />
|
||||
Xiuxia Zhang<br />
|
||||
Haicheng Wu (CUTLASS founding member)<br />
|
||||
Jack Yang<br />
|
||||
Pradeep Ramani (CUTLASS 3.x founding member)<br />
|
||||
Aditya Atluri<br />
|
||||
Han Li<br />
|
||||
Nick Zhao<br />
|
||||
Ivan Yin<br />
|
||||
Yu-Jung Chen<br />
|
||||
Markus Hoehnerbach<br />
|
||||
Honghao Lu<br />
|
||||
Mihir Awatramani<br />
|
||||
Hao Sheng<br />
|
||||
Zekun Fan<br />
|
||||
Aniket Shivam<br />
|
||||
Siyu Liu<br />
|
||||
Richard Cai<br />
|
||||
Vikas Gupta<br />
|
||||
Ethan Yan<br />
|
||||
Vijay Thakkar (CUTLASS 3.x founding member)<br />
|
||||
Cris Cecka (CuTe and CUTLASS 3.x founding member)<br />
|
||||
Lawrence Ryan<br />
|
||||
Qun Song<br />
|
||||
Daniel Ricketts<br />
|
||||
dePaul Miller<br />
|
||||
Yuhan Li<br />
|
||||
Saman Ashkiani<br />
|
||||
Jack Chen<br />
|
||||
Shang Zhang<br />
|
||||
Petrick Liu<br />
|
||||
Questa Wang<br />
|
||||
Pramod Shenoy<br />
|
||||
Jack Kosaian<br />
|
||||
Yujia Zhai<br />
|
||||
Zhaodong Chen<br />
|
||||
Manas Sahni<br />
|
||||
Shunfan Shao<br />
|
||||
Fengqi Qiao<br />
|
||||
Serif Yesil<br />
|
||||
Aragorn Guan<br />
|
||||
Heidi He<br />
|
||||
Xiao Song<br />
|
||||
Sergey Klevtsov<br />
|
||||
Jiang Shao<br />
|
||||
Ruqing Xu<br />
|
||||
Mengyu Guo<br />
|
||||
Tao Xie<br />
|
||||
Linfeng Zheng<br />
|
||||
Harrison Barclay<br />
|
||||
Wenfei Tang<br />
|
||||
Diksha Gohlyan<br />
|
||||
Alexander Zhurkevich<br />
|
||||
Siyuan Fu<br />
|
||||
Hua Huang<br />
|
||||
Xiufan Liang<br />
|
||||
Ian Tramble<br />
|
||||
Ali Hassani<br />
|
||||
Shreya Gaur<br />
|
||||
|
||||
** _The list is sorted in order of the author's first contribution to the CUTLASS project._
|
||||
|
||||
# CUTLASS Product Manager
|
||||
Matthew Nicely<br />
|
||||
60
CHANGELOG.md
60
CHANGELOG.md
@ -1,8 +1,59 @@
|
||||
# NVIDIA CUTLASS Changelog
|
||||
|
||||
## [3.8.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.8.0) (2025-01-25)
|
||||
|
||||
* Support for new CuTe building blocks specifically for Blackwell SM100 architecture:
|
||||
- [5th generation Blackwell Tensor Core instructions (TCGen05)](./include/cute/atom/mma_traits_sm100.hpp) via CuTe MMA atoms.
|
||||
- Extensions to [Tensor Memory Accelerator](./include/cute/atom/copy_traits_sm100_tma.hpp) via CuTe Copy atoms.
|
||||
- Exposure of Blackwell's new tensor memory (note: distinct from TMA) as [`tmem`](./include/cute/pointer.hpp) across CuTe as a first class data locale.
|
||||
- Exposure of [`tmem->rmem`, `rmem->tmem` and `smem->tmem data movement instructions`](./include/cute/atom/copy_traits_sm100.hpp) as copy atoms in CuTe.
|
||||
- [`make_tmem_copy()`](./include/cute/atom/copy_traits_sm100.hpp) utility method to ease creation of tiled copies for tmem copy atoms.
|
||||
- Support for [new variants of LDSM on Blackwell](./include/cute/atom/copy_traits_sm100.hpp) via CuTe Copy atoms.
|
||||
* Support for new CUTLASS building blocks specifically for Blackwell SM100 architecture:
|
||||
- Various narrow precision [FP4, FP6, and FP8](./include/cutlass/exmy_base.h) formats as well as their [block-scaled variants NVFP4, MXFP4, MXFP6, and MXFP8](./include/cutlass/float_subbyte.h)
|
||||
- [Pipelines that implement Blackwell specific synchronization](./include/cutlass/pipeline/sm100_pipeline.hpp).
|
||||
- [Cluster launch control API supporting preferred and fallback cluster shapes](./include/cutlass/cluster_launch.hpp).
|
||||
- Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types.
|
||||
- Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./cutlass/media/docs/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](./include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](./include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp).
|
||||
- Extensions to testbeds and reference check code for unit tests and CUTLASS profiler.
|
||||
* Full support for Blackwell SM100 kernels in CUTLASS 3.x API:
|
||||
- [Blackwell specific kernel layers](./include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that
|
||||
+ Implement a new warp-specialization recipe tuned specifically for Blackwell SM100 architecture.
|
||||
+ Leverage all the new features such as CLC based tile scheduling, preferred cluster, and TMEM based double buffering of accumulators.
|
||||
+ Support stream-K load balancing for all kernel types everywhere via composable scheduler support.
|
||||
- Blackwell collective mainloops that target the TCGen05 MMA instructions (both SS and TS) for
|
||||
* [Non-block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp)
|
||||
* [Non-block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp)
|
||||
* [Block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp)
|
||||
* [Block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp)
|
||||
- Blackwell [collective mainloop for convolution kernels](./include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp) supporting non-block scaled data types for fprop, dgrad, and wgrad.
|
||||
- New [GEMM](./include/cutlass/gemm/dispatch_policy.hpp), [convolution](./include/cutlass/conv/dispatch_policy.hpp), and [epilogue](./include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
|
||||
- [Blackwell epilogue that supports loading accumulators from `tmem`](./include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp) and [full set of EVT fusions]().
|
||||
* CUTLASS library and profiler integration for block scaled data types for kernel emission, profiling, and verification.
|
||||
- Support for preferred and fallback cluster shapes via profiler command line arguments parsing to set dynamic cluster shapes.
|
||||
- Support for dynamic datatypes by parsing profiler via profiler command line arguments parsing to set dynamic datatype setting in TCGen05 MMA instruction descriptors.
|
||||
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM100 architecture:
|
||||
- [Basic FP16 and FP8 GEMMs with minimal changes from Hopper examples](./examples/70_blackwell_gemm/), demonstrating ease of migration for off the shelf kernels using the 3.x collective builder API.
|
||||
- GEMM with [opt-in collective builder schedules showcasing available recipes](./examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu) for Blackwell.
|
||||
- Block scaled data type GEMMs targeting Blackwell's native block scaled Tensor Cores:
|
||||
+ [NVFP4 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu)
|
||||
+ [NVFP4 inputs with NVFP4 output](./examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu)
|
||||
+ [Mixed MXFP8 and MXFP6 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu)
|
||||
- GEMM example demonstrating [Blackwell's new preferred cluster support via dynamic cluster shapes](./examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu) for increased occupancy.
|
||||
- [GEMM with CLC based StreamK scheduler for load balancing](./examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu).
|
||||
- Grouped GEMM for [vanilla FP8 data inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu) and [NVFP4 block scaled inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu).
|
||||
- Convolution kernels for [fprop](./examples/76_blackwell_conv/76_blackwell_conv_fprop.cu), [dgrad](./examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu), and [wgrad](./examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu).
|
||||
- [Fused multi-head attention fprop kernel](./examples/77_blackwell_fmha/77_blackwell_fmha.cu) supporting fp16/bf16/fp8 data types across head dims of 32,64, and 128.
|
||||
* Documentation updates:
|
||||
- [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/quickstart.md#instantiating-a-blackwell-gemm-kernel).
|
||||
- Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/narrow_and_mixed_precision_gemms.md)
|
||||
- A new [functionality documentation](./media/docs/functionality.md) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures.
|
||||
- Updates to [compatibility](./README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](./README.md#Target-Architecture).
|
||||
|
||||
## [3.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.7.0) (2025-01-11)
|
||||
- [Hopper blockwise scaling FP8 GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) uses 2D scaling tensor, assigning one value per threadblock. This allows a finer-grained scaling to be applied for each output tile per gemm-k iteration. The operands and scaling tensors are loaded from global memory to shared memory using TMA and cp_async, respectively. The scaling is applied inside the mainloop. Details with figures are [here](https://github.com/NVIDIA/cutlass/pull/1932#issue-2645398439).
|
||||
- [Distributed GEMM](./examples/65_distributed_gemm/65_distributed_gemm.cu) is a new (experimental) API which can turn existing CUTLASS GEMM kernels into pipelined Tensor Parallel GEMMs that run efficiently on NVLink-based network of GPUs. Its pipelining schedules can hide most of the communication behind computation, and relies on point-to-point communication, which can simply use CUDA runtime's peer device access feature. It also utilizes remote TMA loads and memcopies with CUDA graphs to handle communication primarily through the Copy Engine, leaving all SMs free for Hopper's persistent kernels. For more details you can refer to the [DistGEMM blog post](https://blog.shi-labs.com/distributed-gemm-88be6a481e2b).
|
||||
- Improved persistent grid launch for Hopper kernels with large cluster sizes (>= size of 4) using the new `make_kernel_hardware_info` API as shown in [example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
|
||||
- Improved persistent grid launch for Hopper kernels with large cluster sizes (>= size of 4) using the new `make_kernel_hardware_info` API as shown in [example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
|
||||
- Enabled high precision accumulation for Hopper FP8 Sparse GEMM.
|
||||
- Potential API breaking changes:
|
||||
+ Fix `cute::UniversalCopy` for type safety.
|
||||
@ -22,12 +73,7 @@
|
||||
+ [INT8](./test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu)
|
||||
+ [TF32](./test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu)
|
||||
- A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](./include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered as a beta API.
|
||||
- Improve [mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md).
|
||||
+ Added a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode.
|
||||
+ Added [layout pre-shuffling](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L50-55) to optimize memory loading.
|
||||
+ Added [interleaved conversion](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu#L50-52) for `{INT4, UINT4, INT8}` x `{FP16, BF16}`.
|
||||
+ Other general optimizations.
|
||||
- The suffixes of the mixed input kernel schedules have been removed. Use `KernelTmaWarpSpecialized`, `KernelTmaWarpSpecializedPingpong` and `KernelTmaWarpSpecializedCooperative` instead.
|
||||
- [An improved mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode.
|
||||
- [EVT nodes for Top-K selection and softmax](./include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](./examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu).
|
||||
- [Programmatic Dependent Launch](./include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](./media/docs/dependent_kernel_launch.md).
|
||||
- [A new debugging tool, synclog](./include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details.
|
||||
|
||||
@ -164,6 +164,11 @@ endif()
|
||||
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 90a)
|
||||
endif()
|
||||
|
||||
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.8)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a)
|
||||
endif()
|
||||
|
||||
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
|
||||
set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.")
|
||||
|
||||
@ -383,6 +388,21 @@ endif()
|
||||
|
||||
|
||||
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Blackwell features
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.8)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUDA_BLACKWELL_TMA_SWIZZLE_ENABLED=1)
|
||||
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUDA_ENABLE_PREFERRED_CLUSTER=1)
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
# Warnings-as-error exceptions and warning suppressions for Clang builds
|
||||
if (CUTLASS_CLANG_HOST_COMPILE)
|
||||
|
||||
|
||||
@ -1,87 +0,0 @@
|
||||

|
||||
|
||||
[README](./README.md#documentation) > **Contributors**
|
||||
|
||||
# CUTLASS Developers and Contributors
|
||||
|
||||
This is the official list of CUTLASS developers and contributors.
|
||||
|
||||
## DEVELOPERS
|
||||
Vijay Thakkar<br />
|
||||
Pradeep Ramani<br />
|
||||
Cris Cecka<br />
|
||||
Aniket Shivam<br />
|
||||
Jack Kosaian<br />
|
||||
Mark Hoemmen<br />
|
||||
Richard Cai<br />
|
||||
Honghao Lu<br />
|
||||
Ethan Yan<br />
|
||||
Haicheng Wu<br />
|
||||
Andrew Kerr<br />
|
||||
Dustyn Blasig<br />
|
||||
Fengqi Qiao<br />
|
||||
Duane Merrill<br />
|
||||
Yujia Zhai<br />
|
||||
Rawn Henry<br />
|
||||
Sergey Klevtsov<br />
|
||||
Shang Zhang<br />
|
||||
Piotr Majcher<br />
|
||||
Paul Springer<br />
|
||||
Markus Hohnerbach<br />
|
||||
Jin Wang<br />
|
||||
Aditya Atluri<br />
|
||||
|
||||
## CuTe
|
||||
Cris Cecka<br />
|
||||
Vijay Thakkar<br />
|
||||
|
||||
## CUTLASS Product Manager
|
||||
Matthew Nicely<br />
|
||||
|
||||
## Former CUTLASS Developers
|
||||
Manish Gupta<br />
|
||||
Naila Farooqui<br />
|
||||
David Tanner<br />
|
||||
Manikandan Ananth<br />
|
||||
Zhaodong Chen<br />
|
||||
Chinmay Talegaonkar<br />
|
||||
|
||||
## CONTRIBUTORS
|
||||
Timothy Costa<br />
|
||||
Julien Demouth<br />
|
||||
Brian Fahs<br />
|
||||
Michael Garland<br />
|
||||
Michael Goldfarb<br />
|
||||
Mostafa Hagog<br />
|
||||
Fei Hu<br />
|
||||
Alan Kaatz<br />
|
||||
Tina Li<br />
|
||||
Timmy Liu<br />
|
||||
Wei Liu<br />
|
||||
Tim Martin<br />
|
||||
Duane Merrill<br />
|
||||
Kevin Siu<br />
|
||||
Markus Tavenrath<br />
|
||||
John Tran<br />
|
||||
Vicki Wang<br />
|
||||
Junkai Wu<br />
|
||||
Fung Xie<br />
|
||||
Albert Xu<br />
|
||||
Yang Xu<br />
|
||||
Jack Yang<br />
|
||||
Scott Yokim<br />
|
||||
Xiuxia Zhang<br />
|
||||
Nick Zhao<br />
|
||||
|
||||
## ACKNOWLEDGEMENTS
|
||||
|
||||
Girish Bharambe<br />
|
||||
Luke Durant<br />
|
||||
Carter Edwards<br />
|
||||
Olivier Giroux<br />
|
||||
Stephen Jones<br />
|
||||
Rishkul Kulkarni<br />
|
||||
Bryce Lelbach<br />
|
||||
Joel McCormack<br />
|
||||
Kyrylo Perelygin<br />
|
||||
Sean Treichler<br />
|
||||
235
README.md
235
README.md
@ -1,8 +1,8 @@
|
||||

|
||||
|
||||
# CUTLASS 3.7.0
|
||||
# CUTLASS 3.8.0
|
||||
|
||||
_CUTLASS 3.7.0 - January 2025_
|
||||
_CUTLASS 3.8.0 - January 2025_
|
||||
|
||||
CUTLASS is a collection of CUDA C++ template abstractions for implementing
|
||||
high-performance matrix-matrix multiplication (GEMM) and related computations at all levels
|
||||
@ -16,71 +16,96 @@ as building blocks within custom kernels and applications.
|
||||
|
||||
To support a wide variety of applications, CUTLASS provides extensive support for
|
||||
mixed-precision computations, providing specialized data-movement and
|
||||
multiply-accumulate abstractions for half-precision floating
|
||||
point (FP16), BFloat16 (BF16), Tensor Float 32 (TF32),
|
||||
single-precision floating point (FP32),
|
||||
[FP32 emulation via tensor core instruction](./examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm),
|
||||
double-precision floating
|
||||
point (FP64) types, integer data types (4b and 8b), and binary data types (1b).
|
||||
CUTLASS demonstrates warp-synchronous matrix multiply operations
|
||||
multiply-accumulate abstractions for FP64, FP32, TF32, FP16, BF16,
|
||||
[FP32 emulation via tensor core instruction](./examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm),
|
||||
8b floating point types (e5m2 and e4m3),
|
||||
block scaled data types (NVIDIA NVFP4 and OCP standard MXFP4, MXFP6, MXFP8),
|
||||
narrow integer types (4 and 8b signed and unsigned integers),
|
||||
and binary 1b data types (where architectures allow for the
|
||||
native support of such data types).
|
||||
CUTLASS demonstrates optimal matrix multiply operations
|
||||
targeting the programmable, high-throughput _Tensor Cores_ implemented by
|
||||
NVIDIA's Volta, Turing, Ampere, and Hopper architectures.
|
||||
NVIDIA's Volta, Turing, Ampere, Ada, Hopper, and Blackwell architectures.
|
||||
|
||||
In addition to GEMMs, CUTLASS implements high-performance convolution via
|
||||
the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution
|
||||
operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline.
|
||||
This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components.
|
||||
|
||||
See the [Quick Start Guide](./media/docs/quickstart.md) to get started quickly.
|
||||
|
||||
See the [functionality listing](./media/docs/functionality.md) for the list of operations
|
||||
supported at each level of the execution model hierarchy.
|
||||
See the [functionality docs](./media/docs/functionality.md) for a more comprehensive
|
||||
list of kernel level features, data types, instructions, and minimum supported by CUTLASS on each GPU
|
||||
architecture.
|
||||
|
||||
CUTLASS 3.0 introduced a new core library, CuTe, to describe and manipulate tensors of threads and data.
|
||||
CuTe is a collection of C++ CUDA template abstractions for defining and operating on hierarchically multidimensional layouts of threads and data. CuTe provides `Layout` and `Tensor` objects that compactly package the type, shape, memory space, and layout of data, while performing the complicated indexing for the user. This lets programmers focus on the logical descriptions of their algorithms while CuTe does the mechanical bookkeeping for them. With these tools, we can quickly design, implement, and modify all dense linear algebra operations.
|
||||
# What's New in CUTLASS 3.8
|
||||
|
||||
The core abstractions of CuTe are hierarchically multidimensional layouts which can be composed with data arrays to represent tensors. The representation of layouts is powerful enough to represent nearly everything we need to implement efficient dense linear algebra. Layouts can also be combined and manipulated via functional composition, on which we build a large set of common operations such as tiling and partitioning.
|
||||
CUTLASS 3.8 is the first release that supports the NVIDIA Blackwell SM100 architecture.
|
||||
For a background on Blackwell's new features, please consult the PTX documentation for CUDA 12.8.
|
||||
|
||||
CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates. This greatly simplifies the design
|
||||
and improves code composability and readability. More documentation specific to CuTe can be found in its [dedicated documentation directory](./media/docs/cute/00_quickstart.md).
|
||||
* Support for new CuTe building blocks specifically for Blackwell architecture:
|
||||
- [5th generation Blackwell Tensor Core instructions (TCGen05)](./include/cute/atom/mma_traits_sm100.hpp) via CuTe MMA atoms.
|
||||
- Extensions to [Tensor Memory Accelerator](./include/cute/atom/copy_traits_sm100_tma.hpp) via CuTe Copy atoms.
|
||||
- Exposure of Blackwell's new tensor memory (note: distinct from TMA) as [`tmem`](./include/cute/pointer.hpp#L290) across CuTe as a first class data locale.
|
||||
- Exposure of [`tmem->rmem`, `rmem->tmem` and `smem->tmem data movement instructions`](./include/cute/atom/copy_traits_sm100.hpp) as copy atoms in CuTe.
|
||||
- [`make_tmem_copy()`](./include/cute/atom/copy_traits_sm100.hpp) utility method to ease creation of tiled copies for tmem copy atoms.
|
||||
- Support for [new variants of LDSM on Blackwell](./include/cute/atom/copy_traits_sm100.hpp) via CuTe Copy atoms.
|
||||
* Support for new CUTLASS building blocks specifically for Blackwell architecture:
|
||||
- Various narrow precision [FP4, FP6, and FP8](./include/cutlass/exmy_base.h) formats as well as their [block-scaled variants NVFP4, MXFP4, MXFP6, and MXFP8](./include/cutlass/float_subbyte.h)
|
||||
- [Pipelines that implement Blackwell specific synchronization](./include/cutlass/pipeline/sm100_pipeline.hpp).
|
||||
- [Cluster launch control API supporting preferred and fallback cluster shapes](./include/cutlass/cluster_launch.hpp).
|
||||
- Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types.
|
||||
- Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./cutlass/media/docs/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](./include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](./include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp).
|
||||
- Extensions to testbeds and reference check code for unit tests and CUTLASS profiler.
|
||||
* Full support for Blackwell kernels in CUTLASS 3.x API:
|
||||
- [Blackwell specific kernel layers](./include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that
|
||||
+ Implement a new warp-specialization recipe tuned specifically for Blackwell.
|
||||
+ Leverage all the new features such as CLC based tile scheduling, preferred cluster, and TMEM based double buffering of accumulators.
|
||||
+ Support stream-K load balancing for all kernel types everywhere via composable scheduler support.
|
||||
- Blackwell collective mainloops that target the TCGen05 MMA instructions (both SS and TS) for
|
||||
* [Non-block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp)
|
||||
* [Non-block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp)
|
||||
* [Block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp)
|
||||
* [Block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp)
|
||||
- Blackwell [collective mainloop for convolution kernels](./include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp) supporting non-block scaled data types for fprop, dgrad, and wgrad.
|
||||
- New [GEMM](./include/cutlass/gemm/dispatch_policy.hpp), [convolution](./include/cutlass/conv/dispatch_policy.hpp), and [epilogue](./include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
|
||||
- [Blackwell epilogue that supports loading accumulators from `tmem`](./include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp) and [full set of EVT fusions]().
|
||||
* CUTLASS library and profiler integration for block scaled data types for kernel emission, profiling, and verification.
|
||||
- Support for preferred and fallback cluster shapes via profiler command line arguments parsing to set dynamic cluster shapes.
|
||||
- Support for dynamic datatypes by parsing profiler via profiler command line arguments parsing to set dynamic datatype setting in TCGen05 MMA instruction descriptors.
|
||||
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell
|
||||
- [Basic FP16 and FP8 GEMMs with minimal changes from Hopper examples](./examples/70_blackwell_gemm/), demonstrating ease of migration for off the shelf kernels using the 3.x collective builder API.
|
||||
- GEMM with [opt-in collective builder schedules showcasing available recipes](./examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu) for Blackwell.
|
||||
- Block scaled data type GEMMs targeting Blackwell's native block scaled Tensor Cores:
|
||||
+ [NVFP4 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu)
|
||||
+ [NVFP4 inputs with NVFP4 output](./examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu)
|
||||
+ [Mixed MXFP8 and MXFP6 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu)
|
||||
- GEMM example demonstrating [Blackwell's new preferred cluster support via dynamic cluster shapes](./examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu) for increased occupancy.
|
||||
- [GEMM with CLC based StreamK scheduler for load balancing](./examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu).
|
||||
- Grouped GEMM for [vanilla FP8 data inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu) and [NVFP4 block scaled inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu).
|
||||
- Convolution kernels for [fprop](./examples/76_blackwell_conv/76_blackwell_conv_fprop.cu), [dgrad](./examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu), and [wgrad](./examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu).
|
||||
- [Fused multi-head attention fprop kernel](./examples/77_blackwell_fmha/77_blackwell_fmha.cu) supporting fp16/bf16/fp8 data types across head dims of 32,64, and 128.
|
||||
* Documentation updates:
|
||||
- [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/quickstart.md#instantiating-a-blackwell-gemm-kernel).
|
||||
- Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/narrow_and_mixed_precision_gemms.md)
|
||||
- A new [functionality documentation](./media/docs/functionality.md) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures.
|
||||
- Updates to [compatibility](./README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](./README.md#Target-Architecture).
|
||||
|
||||
In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components.
|
||||
Note: CUTLASS 3.x builds are known to be broken on Windows platforms for all CUDA toolkits.
|
||||
CUTLASS team is working on a fix.
|
||||
|
||||
# What's New in CUTLASS 3.7
|
||||
|
||||
CUTLASS 3.7.0 is an update to CUTLASS adding:
|
||||
|
||||
- A new [Hopper blockwise scaling FP8 GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) where the operands and block scaling tensor are staged via shared memory.
|
||||
- [Distributed GEMM](./examples/65_distributed_gemm/65_distributed_gemm.cu) is an experimental pipelined Tensor Parallelism implementation utilizing existing CUTLASS kernels and CUDA runtime features, which can hide the most of communication behind computation.
|
||||
- Improved persistent grid launch for Hopper kernels with large cluster sizes (>= size of 4) using the new `make_kernel_hardware_info` API as shown in [example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
|
||||
- Enabled high precision accumulation for Hopper FP8 Sparse GEMM.
|
||||
- Potential API breaking changes:
|
||||
+ Fix `cute::UniversalCopy` for type safety.
|
||||
+ No longer implicitly select `cute::SM80_CP_ASYNC_*` based on input tensors. This avoids implicit downstream synchronization requirements. To use `SM80_CP_ASYNC`, users must explicitly select the appropriate CopyAtom.
|
||||
+ Fix `cute::SM80_CP_ASYNC_CACHEALWAYS`, `cute::SM80_CP_ASYNC_CACHEGLOBAL`, `cute::SM80_CP_ASYNC_CACHEALWAYS_ZFILL`, `cute::SM80_CP_ASYNC_CACHEGLOBAL_ZFILL` to avoid implicitly selecting `ZFILL` behavior on predication.
|
||||
+ Remove `cute::copy_vec<T>` in favor of `cute::copy_aligned` and `cute::copy(AutoVectorizingCopyWithAssumedAlignment<NumBits>,...)`.
|
||||
+ A refactor of default epilogue struct `DefaultEpilogue` [API](./include/cutlass/epilogue/collective/default_epilogue.hpp) to avoid reading non-void `ElementC` value for `ElementC = void` kernel.
|
||||
- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](./media/docs/profiler.md#cutlass-profiler).
|
||||
- Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
|
||||
|
||||
Minimum requirements:
|
||||
|
||||
- Architecture: Volta
|
||||
- Compiler: Must support at least C++17
|
||||
- CUDA Toolkit version: 11.4
|
||||
|
||||
Starting from CUTLASS 3.0, CUTLASS removed support for the following:
|
||||
|
||||
- Maxwell and Pascal GPU architectures
|
||||
- Ubuntu 16.04
|
||||
- CUDA 10.2
|
||||
- C++ language versions less than 17.
|
||||
|
||||
**See the [CHANGELOG](CHANGELOG.md) for a detailed listing of releases and updates.**
|
||||
**See the [CHANGELOG](CHANGELOG.md) for details of all past releases and updates.**
|
||||
|
||||
# Performance
|
||||
|
||||
<p align="center"><img src=media/images/cutlass-3.5.1-gemm-peak-performance.png></p>
|
||||
<p align="center"><img src=media/images/cutlass-3.5.1-gemm-peak-performance-fp8.png></p>
|
||||
|
||||
CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
|
||||
they exhibit peak performance comparable to cuBLAS for scalar GEMM
|
||||
computations. The above figure shows the continual CUTLASS performance improvements
|
||||
they exhibit nearly optimal utilization of peak theoretical throughput. The figure below
|
||||
shows CUTLASS 3.8's performance as a % of theoretical peak utilization
|
||||
on various input and output data types when run on NVIDIA Blackwell SM100 architecture GPU.
|
||||
|
||||
<p align="center"><img src=media/images/cutlass-3.8-blackwell-gemm-peak-performance.svg></p>
|
||||
|
||||
The two figures below show the continual CUTLASS performance improvements
|
||||
on an [NVIDIA H100](https://www.nvidia.com/en-us/data-center/h100/) (NVIDIA Hopper architecture) since
|
||||
CUTLASS 3.1.
|
||||
CUTLASS 3.5.1 was compiled with the [CUDA 12.5u1 Toolkit](https://developer.nvidia.com/cuda-downloads).
|
||||
@ -88,20 +113,45 @@ Tensor Core operations are implemented using CUDA's
|
||||
[mma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma) and
|
||||
[wgmma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) instructions.
|
||||
|
||||
<p align="center"><img src=media/images/cutlass-2.9-implicit-gemm-performance.png></p>
|
||||
<p align="center"><img src=media/images/cutlass-3.5.1-gemm-peak-performance.png></p>
|
||||
<p align="center"><img src=media/images/cutlass-3.5.1-gemm-peak-performance-fp8.png></p>
|
||||
|
||||
When using CUTLASS building blocks to construct device-wide implicit gemm (Fprop, Dgrad, and Wgrad)
|
||||
kernels, CUTLASS performance is also comparable to cuDNN when running Resnet-50 layers on an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/)
|
||||
as shown in the above figure. Tensor Core operations are implemented using CUDA's
|
||||
[mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma).
|
||||
# CuTe
|
||||
|
||||
CUTLASS 3.0 introduced a new core library, CuTe, to describe and manipulate tensors of threads and data.
|
||||
CuTe is a collection of C++ CUDA template abstractions for
|
||||
defining and operating on hierarchically multidimensional layouts of threads and data.
|
||||
CuTe provides `Layout` and `Tensor` objects that compactly package the type,
|
||||
shape, memory space, and layout of data, while performing the complicated indexing for the user.
|
||||
This lets programmers focus on the logical descriptions of their algorithms while
|
||||
CuTe does the mechanical bookkeeping for them. With these tools, we can quickly design,
|
||||
implement, and modify all dense linear algebra operations.
|
||||
|
||||
The core abstractions of CuTe are hierarchically multidimensional layouts
|
||||
which can be composed with data arrays to represent tensors.
|
||||
The representation of layouts is powerful enough to represent nearly
|
||||
everything we need to implement efficient dense linear algebra.
|
||||
Layouts can also be combined and manipulated via functional composition, on which we build a large set of common operations such as tiling and partitioning.
|
||||
|
||||
CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates.
|
||||
This greatly simplifies the design and improves code composability and readability.
|
||||
More documentation specific to CuTe can be found in its
|
||||
[dedicated documentation directory](./media/docs/cute/00_quickstart.md).
|
||||
|
||||
# Compatibility
|
||||
|
||||
Minimum requirements:
|
||||
|
||||
- Architecture: Volta (compute capability 7.0)
|
||||
- Compiler: Must support at least C++17
|
||||
- CUDA Toolkit version: 11.4
|
||||
|
||||
CUTLASS requires a C++17 host compiler and
|
||||
performs best when built with the [**CUDA 12.4 Toolkit**](https://developer.nvidia.com/cuda-downloads).
|
||||
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0, CUDA 12.1, CUDA 12.2.2, CUDA 12.3.1 and CUDA 12.3.2.
|
||||
performs best when built with the [**CUDA 12.8 Toolkit**](https://developer.nvidia.com/cuda-downloads).
|
||||
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, and all other CUDA 12.x versions.
|
||||
|
||||
## Operating Systems
|
||||
|
||||
We have tested the following environments.
|
||||
|
||||
|**Operating System** | **Compiler** |
|
||||
@ -109,47 +159,74 @@ We have tested the following environments.
|
||||
| Ubuntu 18.04 | GCC 7.5.0 |
|
||||
| Ubuntu 20.04 | GCC 10.3.0 |
|
||||
| Ubuntu 22.04 | GCC 11.2.0 |
|
||||
| Ubuntu 22.04 | Clang 10.0.0 |
|
||||
| Ubuntu 22.04 | Clang 14.0.6 |
|
||||
| Ubuntu 22.04 | Clang 17.0.6 |
|
||||
| Windows 10.0 | Visual Studio 2019 v16.11.27 |
|
||||
|
||||
Note: GCC 8.5.0 has known regressions regarding fold expressions and overloaded operators. Using GCC 7.5.0 or (preferred) GCC >= 9 is recommended.
|
||||
|
||||
Note: CUTLASS 3.x builds are known to be broken on Windows platforms for all CUDA toolkits.
|
||||
CUTLASS team is working on a fix.
|
||||
|
||||
## Hardware
|
||||
|
||||
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on Volta, Turing, Ampere, Ada, and Hopper architecture based NVIDIA GPUs.
|
||||
|
||||
|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit Required by CUTLASS-3**|
|
||||
|---|---|---|
|
||||
|NVIDIA V100 Tensor Core GPU |7.0|11.4|
|
||||
|NVIDIA TitanV |7.0|11.4|
|
||||
|NVIDIA GeForce RTX 2080 TI, 2080, 2070 |7.5|11.4|
|
||||
|NVIDIA GeForce RTX 20x0 series |7.5|11.4|
|
||||
|NVIDIA T4 |7.5|11.4|
|
||||
|NVIDIA A100 Tensor Core GPU |8.0|11.4|
|
||||
|NVIDIA A10 |8.6|11.4|
|
||||
|NVIDIA GeForce RTX 3090 |8.6|11.4|
|
||||
|NVIDIA GeForce RTX 4090 |8.9|11.8|
|
||||
|NVIDIA GeForce RTX 30x0 series |8.6|11.4|
|
||||
|NVIDIA GeForce RTX 40x0 series |8.9|11.8|
|
||||
|NVIDIA L40 |8.9|11.8|
|
||||
|NVIDIA H100 Tensor Core GPU |9.0|11.8|
|
||||
|NVIDIA H200 Tensor Core GPU |9.0|11.8|
|
||||
|NVIDIA B200 Tensor Core GPU |10.0|12.8|
|
||||
|
||||
## Target Architecture
|
||||
|
||||
In general, PTX code generated for one target architecture can be run on future architectures (i.e., it is forward compatible). However, CUDA 12.0 introduced the concept of "architecture-accelerated features" whose PTX does not have forward compatibility guarantees. Several Hopper PTX instructions fall under this category of architecture-accelerated features, and thus require a `sm_90a` target architecture (note the "a" appended). For more details on this and other architecture-accelerated instructions, please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability).
|
||||
In general, PTX code generated for one target architecture can be run on future architectures
|
||||
(i.e., it is forward compatible).
|
||||
However, CUDA 12.0 introduced the concept of "architecture-accelerated features" whose
|
||||
PTX does not have forward compatibility guarantees.
|
||||
Several Hopper and Blackwell PTX instructions fall under this category of
|
||||
architecture-accelerated features, and thus require a `sm_90a` or `sm100a` target architecture
|
||||
(note the "a" appended). For more details on this and other architecture-accelerated instructions,
|
||||
please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability).
|
||||
|
||||
The target architecture information is passed on to CUTLASS via the cmake flag `CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100, users are required to build CUTLASS with `90a` as the target architecture. If a user accidentally builds a kernel which uses SM90a features (e.g. Hopper Tensor Core Instructions), using the SM90 target (note the lack of "a"), with either CUDA Toolkit 12 or 11.8, the kernel is expected to fail with a runtime error.
|
||||
The target architecture information is passed on to CUTLASS via the cmake flag
|
||||
`CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100,
|
||||
users are required to build CUTLASS with `90a` as the target architecture.
|
||||
If a user accidentally builds a kernel which uses SM90a features
|
||||
(e.g. Hopper Tensor Core Instructions), using the SM90 target
|
||||
(note the lack of "a"), with either CUDA Toolkit 12 or 11.8,
|
||||
the kernel is expected to fail with a runtime error.
|
||||
|
||||
```
|
||||
cmake .. -DCUTLASS_NVCC_ARCHS="90a"
|
||||
cmake .. -DCUTLASS_NVCC_ARCHS="90a"
|
||||
```
|
||||
Or
|
||||
|
||||
```
|
||||
cmake .. -DCUTLASS_NVCC_ARCHS="100a"
|
||||
```
|
||||
|
||||
Please refer to the [functionality documentation](./media/docs/functionality.md) for details on which kernels require which target architectures.
|
||||
Note: The NVIDIA Blackwell SM100 architecture used in the datacenter
|
||||
products has a different compute capability than the one underpinning
|
||||
NVIDIA Blackwell GeForce RTX 50 series GPUs. As a result, kernels
|
||||
compiled for Blackwell SM100 architecture with arch conditional features
|
||||
(using `sm100a`) are not compatible with RTX 50 series GPUs.
|
||||
|
||||
Please refer to the [functionality documentation](./media/docs/functionality.md)
|
||||
for details on which kernels require which target architectures.
|
||||
|
||||
# Documentation
|
||||
|
||||
CUTLASS is described in the following documents and the accompanying
|
||||
[Doxygen documentation](https://nvidia.github.io/cutlass).
|
||||
|
||||
- [Quick Start Guide](./media/docs/quickstart.md) - build and run CUTLASS
|
||||
- [Quick Start Guide](./media/docs/quickstart.md) - basics of building and running CUTLASS
|
||||
- [Functionality](./media/docs/functionality.md) - summarizes functionality available in CUTLASS
|
||||
- [Efficient GEMM in CUDA](./media/docs/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA
|
||||
- [CUTLASS 3.x Design](./media/docs/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
|
||||
@ -163,7 +240,7 @@ CUTLASS is described in the following documents and the accompanying
|
||||
- [Layouts](./media/docs/layout.md) - describes layouts of matrices and tensors in memory
|
||||
- [Tile Iterators](./media/docs/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory
|
||||
- [CUTLASS Profiler](./media/docs/profiler.md) - command-line driven profiling application
|
||||
- [CUTLASS Utilities](./media/docs/utilities.md) - additional templates used to facilate rapid development
|
||||
- [CUTLASS Utilities](./media/docs/utilities.md) - additional templates used to facilitate rapid development
|
||||
- [Dependent kernel launch](./media/docs/dependent_kernel_launch.md) - describes a new feature in Hopper which allows overlapping dependent
|
||||
kernels in the same stream, and how it is used in CUTLASS.
|
||||
|
||||
@ -171,11 +248,11 @@ kernels in the same stream, and how it is used in CUTLASS.
|
||||
We have also described the structure of an efficient GEMM in our talk at the
|
||||
[GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).
|
||||
|
||||
- [CUTLASS: Software Primitives for Dense Linear Algebra at All Levels and Scales within CUDA](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2018-s8854/)
|
||||
- [Developing CUDA Kernels to Push Tensor Cores to the Absolute Limit on NVIDIA A100](https://www.nvidia.com/en-us/on-demand/session/gtcsj20-s21745/)
|
||||
- [Accelerating Convolution with Tensor Cores in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring21-s31883/)
|
||||
- [Accelerating Backward Data Gradient by Increasing Tensor Core Utilization in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41996/)
|
||||
- [CUTLASS: Python API, Enhancements, and NVIDIA Hopper](https://www.nvidia.com/en-us/on-demand/session/gtcfall22-a41131/)
|
||||
- [CUTLASS: Software Primitives for Dense Linear Algebra at All Levels and Scales within CUDA](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2018-s8854/)
|
||||
- [Developing CUDA Kernels to Push Tensor Cores to the Absolute Limit on NVIDIA A100](https://www.nvidia.com/en-us/on-demand/session/gtcsj20-s21745/)
|
||||
- [Accelerating Convolution with Tensor Cores in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring21-s31883/)
|
||||
- [Accelerating Backward Data Gradient by Increasing Tensor Core Utilization in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41996/)
|
||||
- [CUTLASS: Python API, Enhancements, and NVIDIA Hopper](https://www.nvidia.com/en-us/on-demand/session/gtcfall22-a41131/)
|
||||
|
||||
# Building CUTLASS
|
||||
|
||||
|
||||
@ -489,6 +489,14 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -540,6 +540,15 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
else if (__CUDACC_VER_MAJOR__ < 12 || props.major != 9 || props.minor != 0) {
|
||||
std::cout
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture "
|
||||
<< "(compute capability 90) and CUDA 12.0 or greater.\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -356,6 +356,15 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
else if (__CUDACC_VER_MAJOR__ < 12 || props.major != 9 || props.minor != 0) {
|
||||
std::cout
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture "
|
||||
<< "(compute capability 90) and CUDA 12.0 or greater.\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -626,6 +626,13 @@ int main(int argc, const char ** argv) {
|
||||
std::cerr << "This example requires a device with compute capability 90 or higher.\n";
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
else if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
|
||||
if (notSupported) {
|
||||
return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems
|
||||
}
|
||||
|
||||
@ -750,6 +750,13 @@ int main(int argc, char const **argv)
|
||||
std::cerr << "This example requires a device with compute capability 90 or higher.\n";
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
else if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
|
||||
if (notSupported) {
|
||||
return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems
|
||||
}
|
||||
|
||||
@ -572,6 +572,13 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
else if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -619,6 +619,13 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
else if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -524,6 +524,13 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
else if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -489,6 +489,13 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
else if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -7,14 +7,18 @@ When relying on `KernelScheduleAuto`, the main loop supporting different A and B
|
||||
|
||||
This first version only supports mixed type GEMMs using TMA.
|
||||
|
||||
|
||||
## Performance
|
||||
|
||||
While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x {int8, int4}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance for `fp16, bf16` or `fp32` scales and zero-points. For best performance, it is ideal to have the scales and zero-points be the same type.
|
||||
While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x {int8, int4, int2}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance for `fp16`, `bf16` or `fp32` scales and zero-points. For best performance, it is ideal to have the scales and zero-points be the same type as mma's type.
|
||||
|
||||
The scale only mode for `fp8 x int4` is significantly slower than direct conversion mode. There is a lookup-table workaround targeting this mode, as shown in `55_hopper_int4_fp8_gemm.cu`. To use this feature, use `cutlass::Array<ElementScale, 8>` as the scale type in the collective builder. However, it requires modifications to the encoding of quantized weights and scale factors. Also, scale with zero point mode is not supported for now.
|
||||
|
||||
Additionally, it's recommended to reorder the narrow data type tensor such that elements read into register file by the same thread are contiguous in global and shared memory. The user can use the helper function `compute_memory_reordering_atom` and `reorder_tensor` to achieve this. See `55_hopper_int4_fp8_gemm.cu` and `55_hopper_int4_bf16_gemm.cu` for more details.
|
||||
|
||||
We are currently optimizing the following cases:
|
||||
1. Memory bound cases for all types
|
||||
2. `fp8 x {int2, uint2}` case
|
||||
|
||||
## Limitations
|
||||
|
||||
|
||||
@ -151,16 +151,16 @@ void mixed_dtype_profiling(
|
||||
runtimes.reserve(options.iterations);
|
||||
|
||||
for (int iter = 0; iter < options.warmup + options.iterations; ++iter) {
|
||||
cudaEventRecord(start);
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
cudaEventRecord(stop);
|
||||
cudaEventSynchronize(stop);
|
||||
cudaEventRecord(start);
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
cudaEventRecord(stop);
|
||||
cudaEventSynchronize(stop);
|
||||
|
||||
if (iter >= options.warmup) {
|
||||
float milliseconds = 0;
|
||||
cudaEventElapsedTime(&milliseconds, start, stop);
|
||||
runtimes.push_back(milliseconds);
|
||||
}
|
||||
if (iter >= options.warmup) {
|
||||
float milliseconds = 0;
|
||||
cudaEventElapsedTime(&milliseconds, start, stop);
|
||||
runtimes.push_back(milliseconds);
|
||||
}
|
||||
}
|
||||
|
||||
cudaEventDestroy(start);
|
||||
|
||||
@ -33,6 +33,8 @@
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "cutlass/util/device_memory.h"
|
||||
#include "cutlass/integer_subbyte.h"
|
||||
#include "cutlass/float8.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
@ -197,7 +199,6 @@ bool initialize_packed_scale(
|
||||
{
|
||||
cutlass::packed_scale_t<ElementScale> tmp(data_in[i]);
|
||||
data_out[i] = reinterpret_cast<cutlass::Array<ElementScale, 8> const&>(tmp);
|
||||
// std::cout << data_in[i] << ":" << std::hex << static_cast<uint16_t>(data_in[i].storage) << ",\t" << -data_in[i] << ":" << std::hex << static_cast<uint16_t>((-data_in[i]).storage) << std::endl;
|
||||
}
|
||||
try {
|
||||
block_out.copy_from_host(data_out.data());
|
||||
|
||||
@ -519,6 +519,13 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
else if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -737,6 +737,13 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
else if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -507,6 +507,13 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
else if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -576,6 +576,14 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -475,6 +475,13 @@ int main(int argc, char const **args) {
|
||||
<< "later (compute capability 90 or greater).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
else if (props.major != 9 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
@ -133,7 +133,8 @@ using namespace cute;
|
||||
using TP = _8;
|
||||
static constexpr int TP_ = TP{};
|
||||
|
||||
#if (defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && \
|
||||
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
|
||||
|
||||
// Distributed GEMM tiling/sharding schedule
|
||||
// Choices:
|
||||
@ -344,7 +345,8 @@ struct Result {
|
||||
|
||||
};
|
||||
|
||||
#if (defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && \
|
||||
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
|
||||
483
examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu
Normal file
483
examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu
Normal file
@ -0,0 +1,483 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A FP16 dense GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS.
|
||||
|
||||
This example demonstrates minimal set of changes needed to transition from a Hopper CUTLASS 3.x
|
||||
GEMM kernel (see example 48_hopper_warp_specialized_gemm) to a Blackwell 3.x CUTLASS GEMM kernel.
|
||||
|
||||
The Blackwell SM100 CUTLASS kernel uses of the following Blackwell SM100 features:
|
||||
|
||||
1. New series of Tensor Core MMA Instructions (tcgen05) introduced on the Blackwell architecture (sm100a)
|
||||
which have 2x throughput compared to Hopper Tensor Core MMA instructions (WGMMA).
|
||||
|
||||
Note that Hopper WGMMA Tensor Core MMA instructions are not compatible on Blackwell (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
2. A new per-SM memory called Tensor Memory (TMEM) introduced on the Blackwell architecture (sm100a).
|
||||
Blackwell SM100 Tensor Core MMA instructions store their accumulation results in TMEM instead of the
|
||||
Register File. (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/).
|
||||
|
||||
3. An extended flavor of the warp-specialized kernel design introduced in Hopper enabled by use of TMEM
|
||||
which allows us to decouple the execution of MMA and epilogue into separate warps.
|
||||
|
||||
4. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
Usage:
|
||||
$ ./examples/70_blackwell_gemm/70_blackwell_fp16_gemm --m=8192 --n=8192 --k=8192
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = half_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = half_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = float; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
|
||||
// MMA and Cluster Tile Shapes
|
||||
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0
|
||||
using MmaTileShape_MNK = Shape<_256,_128,_64>;
|
||||
// Shape of the threadblocks in a cluster
|
||||
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||
// Shape of the threadblocks participating in a tcgen05 MMA. <1, 1, 1> for cta_group = 1, <2, 1, 1> for cta_group = 2
|
||||
using AtomThrShape_MNK = Shape<_2, _1, _1>;
|
||||
// Shape of the tile computed by each SM
|
||||
using PerSmTileShape_MNK = decltype(shape_div(MmaTileShape_MNK{}, AtomThrShape_MNK{}));
|
||||
|
||||
// Build the epilogue
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
PerSmTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
// Build the mainloop
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
// Compose into a kernel
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using DeviceGemmReference = cutlass::reference::device::Gemm<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(8192), n(8192), k(8192),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "70_blackwell_fp16_gemm\n\n"
|
||||
<< " Blackwell FP16 GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "70_blackwell_fp16_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023) {
|
||||
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(0);
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(-2);
|
||||
} else {
|
||||
scope_max = Element(8);
|
||||
scope_min = Element(-8);
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
||||
|
||||
block_A.reset(options.m * options.k);
|
||||
block_B.reset(options.k * options.n);
|
||||
block_C.reset(options.m * options.n);
|
||||
block_D.reset(options.m * options.n);
|
||||
block_ref_D.reset(options.m * options.n);
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, 1},
|
||||
{block_A.get(), stride_A, block_B.get(), stride_B},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k}));
|
||||
cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.k, options.n}));
|
||||
cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n}));
|
||||
cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n}));
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
DeviceGemmReference gemm_reference;
|
||||
|
||||
// Launch device reference gemm kernel
|
||||
gemm_reference(
|
||||
{options.m, options.n, options.k},
|
||||
ElementAccumulator(options.alpha),
|
||||
ref_A,
|
||||
ref_B,
|
||||
ElementAccumulator(options.beta),
|
||||
ref_C,
|
||||
ref_D);
|
||||
|
||||
// Wait for kernel to finish
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least 100a.
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major != 10 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
671
examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu
Normal file
671
examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu
Normal file
@ -0,0 +1,671 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A FP8 dense GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS.
|
||||
|
||||
This example demonstrates minimal set of changes needed to transition from a Hopper CUTLASS 3.x
|
||||
FP8 GEMM kernel (see example 54_hopper_fp8_warp_specialized_gemm) to a Blackwell SM100 FP8 GEMM kernel.
|
||||
|
||||
This example shows all important fusions used by FP8 gemm kernels,
|
||||
i.e., scale factor for A, B, C, D tensor, the abs_max value of D tensor.
|
||||
|
||||
The Blackwell SM100 CUTLASS kernel uses of the following Blackwell SM100 features:
|
||||
|
||||
1. New series of Tensor Core MMA Instructions (tcgen05) introduced on the Blackwell architecture (sm100a)
|
||||
which have 2x throughput compared to Hopper Tensor Core MMA instructions (WGMMA).
|
||||
|
||||
Note that Hopper WGMMA Tensor Core MMA instructions are not compatible on Blackwell (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
2. A new per-SM memory called Tensor Memory (TMEM) introduced on the Blackwell architecture (sm100a).
|
||||
Blackwell SM100 Tensor Core MMA instructions store their accumulation results in TMEM instead of the
|
||||
Register File. (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/).
|
||||
|
||||
3. An extended flavor of the warp-specialized kernel design introduced in Hopper enabled by use of TMEM
|
||||
which allows us to decouple the execution of MMA and epilogue into separate warps.
|
||||
|
||||
4. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
Usage:
|
||||
$ ./examples/70_blackwell_gemm/70_blackwell_fp8_gemm --m=8192 --n=8192 --k=8192
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
// MMA type
|
||||
using ElementAccumulator = float;
|
||||
|
||||
// Epilogue types
|
||||
using ElementBias = cutlass::half_t;
|
||||
using ElementCompute = float;
|
||||
using ElementAux = ElementC;
|
||||
using LayoutAux = LayoutC;
|
||||
using ElementAmax = float;
|
||||
|
||||
// MMA and Cluster Tile Shapes
|
||||
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0
|
||||
using MmaTileShape_MNK = Shape<_256,_128,_64>;
|
||||
// Shape of the threadblocks in a cluster
|
||||
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||
// Shape of the threadblocks participating in a tcgen05 MMA. <1, 1, 1> for cta_group = 1, <2, 1, 1> for cta_group = 2
|
||||
using AtomThrShape_MNK = Shape<_2, _1, _1>;
|
||||
// Shape of the tile computed by each SM
|
||||
using PerSmTileShape_MNK = decltype(shape_div(MmaTileShape_MNK{}, AtomThrShape_MNK{}));
|
||||
|
||||
using FusionOp = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux<
|
||||
LayoutC, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux, ElementAmax, ElementBias>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
|
||||
PerSmTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutC, AlignmentD,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto,
|
||||
FusionOp
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Extract information from Gemm kernel.
|
||||
using EpilogueOutputOp = typename Gemm::EpilogueOutputOp;
|
||||
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
|
||||
using ElementAmax = typename EpilogueOutputOp::ElementAmax;
|
||||
using ActivationFunctor = typename EpilogueOutputOp::ActivationFn;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using StrideAux = StrideC;
|
||||
|
||||
constexpr bool IsDFp8 =
|
||||
cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
|
||||
cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
|
||||
|
||||
constexpr bool IsAuxFp8 =
|
||||
cute::is_same_v<ElementAux, cutlass::float_e4m3_t> or
|
||||
cute::is_same_v<ElementAux, cutlass::float_e5m2_t>;
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
StrideAux stride_aux;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::HostTensor<ElementA , LayoutA > tensor_A;
|
||||
cutlass::HostTensor<ElementB , LayoutB > tensor_B;
|
||||
cutlass::HostTensor<ElementC , LayoutC > tensor_C;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_D;
|
||||
cutlass::HostTensor<ElementD , LayoutD > tensor_ref_D;
|
||||
cutlass::HostTensor<ElementAux, LayoutAux> tensor_aux;
|
||||
cutlass::HostTensor<ElementAux, LayoutAux> tensor_ref_aux;
|
||||
|
||||
using LayoutScalar = cutlass::layout::PackedVectorLayout;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_alpha;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scalar_beta;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_A;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_B;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_C;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_D;
|
||||
cutlass::HostTensor<ElementScalar, LayoutScalar> scale_aux;
|
||||
cutlass::HostTensor<ElementAmax , LayoutScalar> abs_max_D;
|
||||
cutlass::HostTensor<ElementAmax , LayoutScalar> reference_abs_max_D;
|
||||
cutlass::HostTensor<ElementAmax , LayoutScalar> abs_max_aux;
|
||||
cutlass::HostTensor<ElementAmax , LayoutScalar> reference_abs_max_aux;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
float scale_a = 1.f, scale_b = 1.f, scale_c = 1.f, scale_d = 1.f, scale_aux = 1.f;
|
||||
bool device_scale = false;
|
||||
bool save_aux = true;
|
||||
bool save_amax = true;
|
||||
int iterations = 1000;
|
||||
int m = 1024, n = 512, k = 1024, l = 1;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("l", l);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("scale_a", scale_a, 1.f);
|
||||
cmd.get_cmd_line_argument("scale_b", scale_b, 1.f);
|
||||
cmd.get_cmd_line_argument("scale_c", scale_c, 1.f);
|
||||
cmd.get_cmd_line_argument("scale_d", scale_d, 1.f);
|
||||
cmd.get_cmd_line_argument("scale_aux", scale_aux, 1.f);
|
||||
cmd.get_cmd_line_argument("device_scale", device_scale, false);
|
||||
cmd.get_cmd_line_argument("save_aux", save_aux, true);
|
||||
cmd.get_cmd_line_argument("save_amax", save_amax, true);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "70_blackwell_fp8_gemm\n\n"
|
||||
<< " Blackwell FP8 GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the l extent (batch) of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --scale_a=<f32> Scaling factor for A\n"
|
||||
<< " --scale_b=<f32> Scaling factor for B\n"
|
||||
<< " --scale_c=<f32> Scaling factor for C\n"
|
||||
<< " --scale_d=<f32> Scaling factor for D (ignored for non-fp8 D)\n"
|
||||
<< " --scale_aux=<f32> Scaling factor for the auxiliary tensor (ignored for non-fp8 aux)\n"
|
||||
<< " --device_scale=<bool> Copy scalars to device memory before kernel launch (default: false)\n"
|
||||
<< " --save_aux=<bool> Save the pre-activation as an auxiliary tensor (default: true)\n"
|
||||
<< " --save_amax=<bool> Save the pre-scaled max absolute value of any fp8 outputs (aux and/or D) (default: true)\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "70_blackwell_fp8_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
int bits_output = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if (bits_output == 16) {
|
||||
scope_max = 5;
|
||||
scope_min = -5;
|
||||
}
|
||||
else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
|
||||
stride_aux = stride_D;
|
||||
|
||||
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
|
||||
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
|
||||
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
|
||||
|
||||
tensor_A.resize(a_coord);
|
||||
tensor_B.resize(b_coord);
|
||||
tensor_C.resize(c_coord);
|
||||
tensor_D.resize(c_coord);
|
||||
tensor_ref_D.resize(c_coord);
|
||||
|
||||
initialize_tensor(tensor_A.host_view(), seed + 2022);
|
||||
initialize_tensor(tensor_B.host_view(), seed + 2023);
|
||||
initialize_tensor(tensor_C.host_view(), seed + 2024);
|
||||
|
||||
tensor_A.sync_device();
|
||||
tensor_B.sync_device();
|
||||
tensor_C.sync_device();
|
||||
tensor_D.sync_device();
|
||||
|
||||
if (options.save_aux) {
|
||||
tensor_aux.resize(c_coord);
|
||||
tensor_aux.sync_device();
|
||||
tensor_ref_aux.resize(c_coord);
|
||||
}
|
||||
|
||||
if (options.device_scale) {
|
||||
scalar_alpha.resize(cutlass::make_Coord(1));
|
||||
scalar_beta.resize(cutlass::make_Coord(1));
|
||||
scale_A.resize(cutlass::make_Coord(1));
|
||||
scale_B.resize(cutlass::make_Coord(1));
|
||||
scale_C.resize(cutlass::make_Coord(1));
|
||||
scale_D.resize(cutlass::make_Coord(1));
|
||||
scale_aux.resize(cutlass::make_Coord(1));
|
||||
|
||||
cutlass::reference::host::TensorFill(scalar_alpha.host_view(), options.alpha);
|
||||
cutlass::reference::host::TensorFill(scalar_beta.host_view(), options.beta);
|
||||
cutlass::reference::host::TensorFill(scale_A.host_view(), options.scale_a);
|
||||
cutlass::reference::host::TensorFill(scale_B.host_view(), options.scale_b);
|
||||
cutlass::reference::host::TensorFill(scale_C.host_view(), options.scale_c);
|
||||
cutlass::reference::host::TensorFill(scale_D.host_view(), options.scale_d);
|
||||
cutlass::reference::host::TensorFill(scale_aux.host_view(), options.scale_aux);
|
||||
|
||||
scalar_alpha.sync_device();
|
||||
scalar_beta.sync_device();
|
||||
scale_A.sync_device();
|
||||
scale_B.sync_device();
|
||||
scale_C.sync_device();
|
||||
scale_D.sync_device();
|
||||
scale_aux.sync_device();
|
||||
}
|
||||
|
||||
if (IsDFp8 && options.save_amax) {
|
||||
abs_max_D.resize(cutlass::make_Coord(1));
|
||||
abs_max_D.sync_device();
|
||||
reference_abs_max_D.resize(cutlass::make_Coord(1));
|
||||
}
|
||||
|
||||
if (IsAuxFp8 && options.save_aux && options.save_amax) {
|
||||
abs_max_aux.resize(cutlass::make_Coord(1));
|
||||
abs_max_aux.sync_device();
|
||||
reference_abs_max_aux.resize(cutlass::make_Coord(1));
|
||||
}
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, options.l},
|
||||
{tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B},
|
||||
{
|
||||
{}, // epilogue.thread
|
||||
tensor_C.device_data(), stride_C,
|
||||
tensor_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
|
||||
auto &fusion_args = arguments.epilogue.thread;
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.beta = options.beta;
|
||||
fusion_args.alpha_ptr = scalar_alpha.device_data();
|
||||
fusion_args.beta_ptr = scalar_beta.device_data();
|
||||
fusion_args.scale_a = options.scale_a;
|
||||
fusion_args.scale_b = options.scale_b;
|
||||
fusion_args.scale_c = options.scale_c;
|
||||
fusion_args.scale_a_ptr = scale_A.device_data();
|
||||
fusion_args.scale_b_ptr = scale_B.device_data();
|
||||
fusion_args.scale_c_ptr = scale_C.device_data();
|
||||
|
||||
// ignored if tensor types are not fp8
|
||||
fusion_args.scale_d = options.scale_d;
|
||||
fusion_args.scale_aux = options.scale_aux;
|
||||
fusion_args.scale_d_ptr = scale_D.device_data();
|
||||
fusion_args.scale_aux_ptr = scale_aux.device_data();
|
||||
|
||||
// leaving/setting these as nullptr disables the fusion at runtime
|
||||
fusion_args.bias_ptr = nullptr;
|
||||
|
||||
if (options.save_aux) {
|
||||
fusion_args.aux_ptr = tensor_aux.device_data();
|
||||
fusion_args.dAux = stride_aux;
|
||||
if (options.save_amax) {
|
||||
fusion_args.amax_aux_ptr = abs_max_aux.device_data();
|
||||
}
|
||||
}
|
||||
|
||||
if (options.save_amax) {
|
||||
fusion_args.amax_D_ptr = abs_max_D.device_data();
|
||||
}
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
auto A = cute::make_tensor(tensor_A.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A));
|
||||
auto B = cute::make_tensor(tensor_B.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B));
|
||||
auto C = cute::make_tensor(tensor_C.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C));
|
||||
auto D = cute::make_tensor(tensor_ref_D.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D));
|
||||
auto Aux = cute::make_tensor(tensor_ref_aux.host_data(),
|
||||
cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_aux));
|
||||
using unused_t = decltype(D);
|
||||
|
||||
cutlass::reference::host::GettMainloopParams<ElementAccumulator, decltype(A), decltype(B)> mainloop_params{A, B};
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
ElementScalar,
|
||||
ElementScalar,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
decltype(C),
|
||||
decltype(D),
|
||||
unused_t, // bias
|
||||
decltype(Aux),
|
||||
unused_t, // valpha
|
||||
unused_t, // vbeta
|
||||
ActivationFunctor
|
||||
> epilogue_params;
|
||||
|
||||
epilogue_params.C = C;
|
||||
epilogue_params.D = D;
|
||||
epilogue_params.Aux = Aux;
|
||||
epilogue_params.alpha = options.alpha;
|
||||
epilogue_params.beta = options.beta;
|
||||
epilogue_params.scale_a = options.scale_a;
|
||||
epilogue_params.scale_b = options.scale_b;
|
||||
epilogue_params.scale_c = options.scale_c;
|
||||
epilogue_params.scale_d = options.scale_d;
|
||||
epilogue_params.scale_aux = options.scale_aux;
|
||||
epilogue_params.abs_max_D = reference_abs_max_D.host_data();
|
||||
epilogue_params.abs_max_Aux = reference_abs_max_aux.host_data();
|
||||
|
||||
// get reference result
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// compare_reference
|
||||
tensor_D.sync_host();
|
||||
bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view());
|
||||
|
||||
if (IsDFp8 && options.save_amax) {
|
||||
abs_max_D.sync_host();
|
||||
passed &= abs_max_D.at(cutlass::make_Coord(0)) == reference_abs_max_D.at(cutlass::make_Coord(0));
|
||||
}
|
||||
|
||||
if (options.save_aux) {
|
||||
tensor_aux.sync_host();
|
||||
passed &= cutlass::reference::host::TensorEquals(tensor_ref_aux.host_view(), tensor_aux.host_view());
|
||||
if (IsAuxFp8 && options.save_amax) {
|
||||
abs_max_aux.sync_host();
|
||||
passed &= abs_max_aux.at(cutlass::make_Coord(0)) == reference_abs_max_aux.at(cutlass::make_Coord(0));
|
||||
}
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least sm100a.
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12) {
|
||||
std::cerr << "This example requires CUDA 12 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major != 10 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Run
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
41
examples/70_blackwell_gemm/CMakeLists.txt
Normal file
41
examples/70_blackwell_gemm/CMakeLists.txt
Normal file
@ -0,0 +1,41 @@
|
||||
|
||||
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
|
||||
cutlass_example_add_executable(
|
||||
70_blackwell_fp16_gemm
|
||||
70_blackwell_fp16_gemm.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
70_blackwell_fp8_gemm
|
||||
70_blackwell_fp8_gemm.cu
|
||||
)
|
||||
endif()
|
||||
@ -0,0 +1,570 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Blackwell SM100 GEMM example demonstrating compatible mainloop+epilogue builder schedules
|
||||
and epilogue visitor tree (EVT) construction
|
||||
|
||||
Example usage:
|
||||
$ ./examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder \
|
||||
--m=2048 --n=2048 --k=2048 --l=2
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm_complex.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
bool error;
|
||||
|
||||
int m, n, k, l;
|
||||
float alpha, beta;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
error(false),
|
||||
m(2048), n(2048), k(2048), l(1),
|
||||
alpha(1.f), beta(0.f)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m, 2048);
|
||||
cmd.get_cmd_line_argument("n", n, 2048);
|
||||
cmd.get_cmd_line_argument("k", k, 2048);
|
||||
cmd.get_cmd_line_argument("l", l, 1);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "71_blackwell_gemm_with_collective_builder\n\n"
|
||||
<< " This example showcases the use of CUTLASS's collective operation builders to easily construct\n"
|
||||
<< " performant kernels targeting NVIDIA's Blackwell architecture.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --l=<int> Sets the L extent (batch count) of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023) {
|
||||
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
} else {
|
||||
scope_max = 8;
|
||||
scope_min = -8;
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
// Wrapper to construct, run, and verify a GEMM. This example showcases CUTLASS's collective
|
||||
// operation builders by specializing the GEMM on the kernel+epilogue schedule it will use and the
|
||||
// number of pipeline stages.
|
||||
template <
|
||||
// Type of kernel schedule to generate
|
||||
class MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto,
|
||||
// Type of epilogue schedule to generate
|
||||
class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto,
|
||||
// Number of pipeline stages to use
|
||||
class StageCountType = cutlass::gemm::collective::StageCountAuto,
|
||||
// Do we use custom epilogue visitor tree (EVT) fusion
|
||||
bool UseCustomEVT = false
|
||||
>
|
||||
struct ExampleRunner {
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
using LayoutD = cutlass::layout::ColumnMajor;
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementD = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
using ElementScalar = float;
|
||||
|
||||
using ClusterShapeMNK = Shape<_2,_2,_1>;
|
||||
static constexpr bool Use2SmMma =
|
||||
// Manually specified 2sm cluster MMA schedule, will error if cluster M is not a multiple of 2
|
||||
std::is_same_v<MainloopScheduleType, cutlass::gemm::KernelTmaWarpSpecialized2SmSm100> ||
|
||||
// Auto schedule will try to select 2sm cluster MMA based on cluster M
|
||||
std::is_same_v<MainloopScheduleType, cutlass::gemm::collective::KernelScheduleAuto> && size<0>(ClusterShapeMNK{}) % 2 == 0;
|
||||
// The MNK layout of CTAs within a cluster MMA
|
||||
using AtomThrMNK = std::conditional_t<Use2SmMma, Shape<_2,_1,_1>, Shape<_1,_1,_1>>;
|
||||
// The MMA tile used by the mainloop collective. Blackwell 1sm MMA supports up to MMA tile M = 128, 2sm MMA supports up to MMA tile M = 256
|
||||
using MmaTileMNK = std::conditional_t<Use2SmMma, Shape<_256,_128,_64>, Shape<_128,_128,_64>>;
|
||||
// The Output tile used by the epilogue collective
|
||||
using OutputTileMNK = decltype(shape_div(MmaTileMNK{}, AtomThrMNK{}));
|
||||
|
||||
// 16B alignment lets us use TMA
|
||||
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
|
||||
// Blackwell fusions for the most part use the same EVT nodes used in Hopper. Most Blackwell EVTs will alias to their Hopper counterparts.
|
||||
// EVT nodes new to Blackwell mainly relate to narrow precision scale factor generation and are contained in include/cutlass/epilogue/fusion/sm100_visitor_*.hpp
|
||||
// See include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp for EVT construction using these new nodes
|
||||
// Fusions relating to narrow-precision scale factor generation are demonstrated in example 72b and can only be used in blackwell kernels
|
||||
using CustomEVT = // alpha * acc + beta * C
|
||||
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::homogeneous_multiply_add, ElementD, ElementCompute, RoundStyle>, // beta * C + (alpha * acc)
|
||||
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementScalar>, // beta
|
||||
cutlass::epilogue::fusion::Sm90SrcFetch<ElementC>, // C
|
||||
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
|
||||
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementScalar>, // alpha
|
||||
cutlass::epilogue::fusion::Sm90AccFetch // acc
|
||||
>
|
||||
>;
|
||||
|
||||
// As in Hopper, a predefined set of fusion operations are provided in include/cutlass/epilogue/fusion/operations.hpp and can be passed to the epilogue builder
|
||||
// Fusions operations supported by the Hopper TMA epilogue will also be supported by the Blackwell TMA epilogue
|
||||
// Fusions relating to narrow-precision scale factor generation are demonstrated in example 72b and can only be used in blackwell kernels
|
||||
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
|
||||
OutputTileMNK, ClusterShapeMNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD,
|
||||
EpilogueScheduleType,
|
||||
cute::conditional_t<UseCustomEVT, CustomEVT, DefaultOperation>
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileMNK, ClusterShapeMNK,
|
||||
cute::conditional_t<cute::is_same_v<StageCountType, cutlass::gemm::collective::StageCountAuto>,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
StageCountType>,
|
||||
MainloopScheduleType
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
using LayoutTagA = cutlass::gemm::detail::StrideToLayoutTagA_t<StrideA>;
|
||||
using LayoutTagB = cutlass::gemm::detail::StrideToLayoutTagB_t<StrideB>;
|
||||
using LayoutTagC = cutlass::gemm::detail::StrideToLayoutTagC_t<StrideC>;
|
||||
using LayoutTagD = cutlass::gemm::detail::StrideToLayoutTagC_t<StrideD>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
uint64_t seed = 0;
|
||||
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementD> block_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementD> block_ref_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
bool verify(const ProblemShapeType& problem_size, float alpha, float beta) {
|
||||
auto [M, N, K, L] = problem_size;
|
||||
|
||||
cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({M, K}));
|
||||
cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({K, N}));
|
||||
cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({M, N}));
|
||||
cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({M, N}));
|
||||
|
||||
cutlass::reference::device::GemmComplex(
|
||||
{M, N, K},
|
||||
ElementScalar(alpha),
|
||||
ref_A,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
ref_B,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
ElementScalar(beta),
|
||||
ref_C,
|
||||
ref_D,
|
||||
ElementAccumulator(0),
|
||||
L, // batch_count
|
||||
M * K, // batch_stride_A
|
||||
K * N, // batch_stride_B
|
||||
M * N, // batch_stride_C
|
||||
M * N // batch_stride_D
|
||||
);
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Reference kernel failed. Last CUDA error: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const ProblemShapeType& problem_size) {
|
||||
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
|
||||
|
||||
block_A.reset(M * K * L);
|
||||
block_B.reset(K * N * L);
|
||||
block_C.reset(M * N * L);
|
||||
block_D.reset(M * N * L);
|
||||
block_ref_D.reset(M * N * L);
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
}
|
||||
|
||||
bool run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
|
||||
ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l};
|
||||
|
||||
initialize(problem_size);
|
||||
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
problem_size,
|
||||
{block_A.get(), stride_A, block_B.get(), stride_B},
|
||||
{{}, // epilogue.thread
|
||||
block_C.get(), stride_C, block_D.get(), stride_D},
|
||||
hw_info
|
||||
};
|
||||
|
||||
// See example 48 for details on custom EVT construction
|
||||
if constexpr (UseCustomEVT) {
|
||||
arguments.epilogue.thread =
|
||||
{ // ternary op : beta * C + (alpha * acc)
|
||||
{{options.beta}}, // leaf op+args : beta
|
||||
{}, // leaf op+args : C
|
||||
{ // binary op : alpha * acc
|
||||
{{options.alpha}}, // leaf op+args : alpha
|
||||
{}, // leaf op+args : acc
|
||||
{} // binary args : multiplies
|
||||
}, // end binary op
|
||||
{} // ternary args : multiply_add
|
||||
}; // end ternary op
|
||||
}
|
||||
// Pre-defined fusions will have flat, named args for user-friendlyness
|
||||
else {
|
||||
arguments.epilogue.thread.alpha = options.alpha;
|
||||
arguments.epilogue.thread.beta = options.beta;
|
||||
}
|
||||
|
||||
Gemm gemm_op;
|
||||
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "This kernel is not supported. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
status = gemm_op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Run the GEMM
|
||||
status = gemm_op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Verify that the result is correct
|
||||
bool passed = verify(problem_size, options.alpha, options.beta);
|
||||
if (!passed) {
|
||||
std::cerr << "Reference check failed" << std::endl;
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to print a description of the example run and its result
|
||||
void print_result(const std::string& description, bool passed) {
|
||||
std::cout << description << ": " << (passed ? "Passed" : "Failed") << std::endl;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.error) {
|
||||
std::cerr << "Aborting execution." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
//
|
||||
// Run examples
|
||||
//
|
||||
|
||||
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
|
||||
// information is used by the underlying kernel.
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
bool passed;
|
||||
|
||||
// Auto mainloop and epilogue schedules must be used together to guarantee functionality
|
||||
ExampleRunner<> runner_0;
|
||||
passed = runner_0.run(options, hw_info);
|
||||
print_result("KernelScheduleAuto mainloop schedule with EpilogueScheduleAuto epilogue schedule", passed);
|
||||
|
||||
// Mainloop stage counts can be specified manually
|
||||
// It is the user's responsibility to ensure there is enough device smem to allocate manual stage counts
|
||||
ExampleRunner<
|
||||
cutlass::gemm::collective::KernelScheduleAuto,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto,
|
||||
_3> runner_1;
|
||||
passed = runner_1.run(options, hw_info);
|
||||
print_result("KernelScheduleAuto mainloop schedule with EpilogueScheduleAuto epilogue schedule and 3 mainloop stages", passed);
|
||||
|
||||
// 1SM cluster MMA mainloop schedules can be used with direct store ("no-smem") epilogue schedules
|
||||
ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized1SmSm100, cutlass::epilogue::NoSmemWarpSpecialized> runner_2;
|
||||
passed = runner_2.run(options, hw_info);
|
||||
print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with NoSmemWarpSpecialized epilogue schedule", passed);
|
||||
|
||||
// 1SM cluster MMA mainloop schedules can also be used with 1SM TMA epilogue schedules
|
||||
// 1SM cluster MMA mainloop schedules will not work with 2SM TMA epilogue schedules
|
||||
ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized1SmSm100, cutlass::epilogue::TmaWarpSpecialized1Sm> runner_3;
|
||||
passed = runner_3.run(options, hw_info);
|
||||
print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with NoSmemWarpSpecialized epilogue schedule", passed);
|
||||
|
||||
// 2SM cluster MMA mainloop schedules can be used with direct store ("no-smem") epilogue schedules
|
||||
ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized2SmSm100, cutlass::epilogue::NoSmemWarpSpecialized> runner_4;
|
||||
passed = runner_4.run(options, hw_info);
|
||||
print_result("KernelTmaWarpSpecialized2SmSm100 mainloop schedule with NoSmemWarpSpecialized epilogue schedule", passed);
|
||||
|
||||
// 2SM cluster MMA mainloop schedules can also be used with 2SM TMA epilogue schedules
|
||||
// 2SM cluster MMA mainloop schedules will not work with SM TMA epilogue schedules
|
||||
ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized2SmSm100, cutlass::epilogue::TmaWarpSpecialized2Sm> runner_5;
|
||||
passed = runner_5.run(options, hw_info);
|
||||
print_result("KernelTmaWarpSpecialized2SmSm100 mainloop schedule with TmaWarpSpecialized2Sm epilogue schedule", passed);
|
||||
|
||||
// Blackwell Auto schedule supports custom EVT fusions
|
||||
constexpr bool UseCustomEVT = true;
|
||||
ExampleRunner<
|
||||
cutlass::gemm::collective::KernelScheduleAuto,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
UseCustomEVT> runner_6;
|
||||
passed = runner_6.run(options, hw_info);
|
||||
print_result("KernelScheduleAuto mainloop schedule with EpilogueScheduleAuto epilogue schedule and custom EVT", passed);
|
||||
|
||||
// 1SM TMA epilogue schedules support custom EVT fusions
|
||||
ExampleRunner<
|
||||
cutlass::gemm::KernelTmaWarpSpecialized1SmSm100,
|
||||
cutlass::epilogue::TmaWarpSpecialized1Sm,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
UseCustomEVT> runner_7;
|
||||
passed = runner_7.run(options, hw_info);
|
||||
print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with TmaWarpSpecialized1Sm epilogue and custom EVT", passed);
|
||||
|
||||
// 2SM TMA epilogue schedules support custom EVT fusions
|
||||
ExampleRunner<
|
||||
cutlass::gemm::KernelTmaWarpSpecialized2SmSm100,
|
||||
cutlass::epilogue::TmaWarpSpecialized2Sm,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
UseCustomEVT> runner_8;
|
||||
passed = runner_8.run(options, hw_info);
|
||||
print_result("KernelTmaWarpSpecialized2SmSm100 mainloop schedule with TmaWarpSpecialized2Sm epilogue and custom EVT", passed);
|
||||
|
||||
|
||||
// Blackwell direct store epilogue schedule supports custom EVTs and named fusion operations as well (not supported for pre-Blackwell kernels)
|
||||
ExampleRunner<
|
||||
cutlass::gemm::KernelTmaWarpSpecialized1SmSm100,
|
||||
cutlass::epilogue::NoSmemWarpSpecialized,
|
||||
cutlass::gemm::collective::StageCountAuto,
|
||||
UseCustomEVT> runner_9;
|
||||
passed = runner_9.run(options, hw_info);
|
||||
print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with NoSmemWarpSpecialized epilogue and custom EVT", passed);
|
||||
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,35 @@
|
||||
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# Both filenames are shorter to avoid MAX_PATH issues on Windows.
|
||||
if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
|
||||
cutlass_example_add_executable(
|
||||
71_blackwell_gemm_with_collective_builder
|
||||
71_blackwell_gemm_with_collective_builder.cu
|
||||
)
|
||||
endif()
|
||||
@ -0,0 +1,544 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
This example demonstrates a simple way to instantiate and run a blockscaled NVFP4 GEMM on the NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
The Blackwell SM100 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (tcgen05.mma.blockscaled) introduced
|
||||
on the Blackwell architecture (sm100a) which have 2x throughput compared to fp8 Tensor Core MMA instructions (tcgen05.mma)
|
||||
and 4x throughput compared to fp8 Hopper Tensor Core MMA Instructions (WGMMA) (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
Similar to 70_blackwell_gemm, this kernel leverages:
|
||||
1. Per-SM memory called Tensor Memory (TMEM) (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/).
|
||||
|
||||
2. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM
|
||||
which allows us to decouple the execution of MMA and epilogue into separate warps.
|
||||
|
||||
3. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
Usage:
|
||||
|
||||
$ ./examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm --m=2048 --n=2048 --k=2048
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for A matrix operand
|
||||
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for A matrix operand
|
||||
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
|
||||
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
|
||||
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
|
||||
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
|
||||
|
||||
// Kernel Perf config
|
||||
using MmaTileShape = Shape<_256,_256,_256>; // MMA's tile size
|
||||
using ClusterShape = Shape<_4,_4,_1>; // Shape of the threadblocks in a cluster
|
||||
using PerSmTileShape_MNK = Shape<_128,_256,_256>; // Threadblock-level tile size
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
PerSmTileShape_MNK, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutCTag, AlignmentC,
|
||||
ElementD, LayoutDTag, AlignmentD,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutATag, AlignmentA,
|
||||
ElementB, LayoutBTag, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
|
||||
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
LayoutA layout_A;
|
||||
LayoutSFA layout_SFA;
|
||||
StrideB stride_B;
|
||||
LayoutB layout_B;
|
||||
LayoutSFB layout_SFB;
|
||||
StrideC stride_C;
|
||||
LayoutC layout_C;
|
||||
StrideD stride_D;
|
||||
LayoutD layout_D;
|
||||
uint64_t seed;
|
||||
|
||||
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
|
||||
// Use cute::Tensor and cute::Layout for iterating thru the matrix elements
|
||||
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
|
||||
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
|
||||
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
|
||||
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
|
||||
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
|
||||
// Output Tensor
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
|
||||
// Reference Output Tensor
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
using namespace cute;
|
||||
if constexpr (cute::is_subbyte_v<T>) {
|
||||
return subbyte_iterator<T>(ptr);
|
||||
}
|
||||
else {
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(1024), n(1024), k(1024),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "72a_blackwell_nvfp4_bf16_gemm\n\n"
|
||||
<< " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ " << "./examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_block(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if constexpr (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if constexpr (bits_input <= 6) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if constexpr (bits_input <= 8) {
|
||||
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
|
||||
scope_max = 4;
|
||||
scope_min = 1;
|
||||
}
|
||||
else {
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
}
|
||||
}
|
||||
else{
|
||||
scope_max = 4;
|
||||
scope_min = -4;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
using namespace cute;
|
||||
// For SFA and SFB tensors layouts
|
||||
using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
||||
|
||||
layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A);
|
||||
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
|
||||
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
|
||||
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
|
||||
layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
|
||||
block_A.reset(cutlass::make_Coord(size(layout_A)));
|
||||
block_B.reset(cutlass::make_Coord(size(layout_B)));
|
||||
block_C.reset(cutlass::make_Coord(size(layout_C)));
|
||||
block_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
|
||||
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
|
||||
|
||||
initialize_block(block_A.host_view(), seed + 2021);
|
||||
initialize_block(block_B.host_view(), seed + 2022);
|
||||
initialize_block(block_C.host_view(), seed + 2023);
|
||||
initialize_block(block_SFA.host_view(), seed + 2024);
|
||||
initialize_block(block_SFB.host_view(), seed + 2025);
|
||||
|
||||
block_A.sync_device();
|
||||
block_B.sync_device();
|
||||
block_C.sync_device();
|
||||
block_SFA.sync_device();
|
||||
block_SFB.sync_device();
|
||||
}
|
||||
|
||||
// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
typename Gemm::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, 1},
|
||||
{ // Mainloop arguments
|
||||
block_A.device_data(), stride_A,
|
||||
block_B.device_data(), stride_B,
|
||||
block_SFA.device_data(), layout_SFA,
|
||||
block_SFB.device_data(), layout_SFB
|
||||
},
|
||||
{ // Epilogue arguments
|
||||
{options.alpha, options.beta},
|
||||
block_C.device_data(), stride_C,
|
||||
block_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
using namespace cute;
|
||||
// Create the arguments for host reference implementation
|
||||
Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A);
|
||||
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
|
||||
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
|
||||
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
decltype(tensor_A), // TensorA
|
||||
decltype(tensor_SFA), // TensorSfA
|
||||
decltype(tensor_B), // TensorB
|
||||
decltype(tensor_SFB) // TensorSfB
|
||||
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
|
||||
|
||||
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
|
||||
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingEpilogueParams<
|
||||
ElementAccumulator, // ElementScalar
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
ElementAccumulator, // ElementCompute
|
||||
decltype(tensor_C), // TensorC
|
||||
decltype(tensor_D) // TensorD
|
||||
> epilogue_params{options.alpha, options.beta, tensor_C, tensor_D};
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// Comparison
|
||||
block_D.sync_host();
|
||||
bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
|
||||
// and must have compute capability at least 100.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,594 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
This example demonstrate a simple way to instantiate and run a blockscaled NVFP4 GEMM on the NVIDIA Blackwell SM100 architecture
|
||||
on NVIDIA Blackwell SM100 architecture. The kernel outputs quantized fp4 values with scale factors that be the input of another GEMM.
|
||||
|
||||
Similar to 72a_blackwell_nvfp4_bf16_gemm, this kernel leverages:
|
||||
1. Blockscaled tcgen05.mma instructions.
|
||||
|
||||
2. Per-SM memory called Tensor Memory (TMEM)
|
||||
|
||||
3. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM
|
||||
which allows us to decouple the execution of MMA and epilogue into separate warps.
|
||||
|
||||
4. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
Usage:
|
||||
|
||||
$ ./examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm --m=2048 --n=2048 --k=2048
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for A matrix operand
|
||||
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>; // Element type for A matrix operand
|
||||
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementD = cutlass::float_e2m1_t; // Element type for D matrix operand
|
||||
using ElementSFD = cutlass::float_ue8m0_t; // Element type for SFB matrix operand
|
||||
using ElementC = float; // Element type for C matrix operand
|
||||
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
|
||||
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
|
||||
using LayoutSFDTag = LayoutDTag; // Layout type for SFD should be same as D matrix operand
|
||||
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
|
||||
|
||||
// Kernel Perf config
|
||||
using MmaTileShape = Shape<_128,_128,_256>; // MMA's tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
using PerSmTileShape_MNK = Shape<_128,_128,_256>; // Threadblock-level tile size
|
||||
|
||||
constexpr int InputSFVectorSize = 16;
|
||||
constexpr int OutputSFVectorSize = InputSFVectorSize;
|
||||
|
||||
// D = alpha * acc + beta * C
|
||||
// With BlockScaleFactor generation.
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor<
|
||||
OutputSFVectorSize,
|
||||
ElementD,
|
||||
ElementCompute,
|
||||
ElementSFD, LayoutSFDTag,
|
||||
ElementC>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
PerSmTileShape_MNK, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutCTag, AlignmentC,
|
||||
ElementD, LayoutDTag, AlignmentD,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy
|
||||
FusionOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutATag, AlignmentA,
|
||||
ElementB, LayoutBTag, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
|
||||
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
|
||||
|
||||
using FusionOp = typename Gemm::EpilogueOutputOp;
|
||||
constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported;
|
||||
using SfdOutputCfg = cutlass::detail::Sm100BlockScaledOutputConfig<OutputSFVectorSize>;
|
||||
using LayoutSFD = typename SfdOutputCfg::LayoutSF;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
LayoutA layout_A;
|
||||
LayoutSFA layout_SFA;
|
||||
StrideB stride_B;
|
||||
LayoutB layout_B;
|
||||
LayoutSFB layout_SFB;
|
||||
StrideC stride_C;
|
||||
LayoutC layout_C;
|
||||
StrideD stride_D;
|
||||
LayoutD layout_D;
|
||||
LayoutSFD layout_SFD;
|
||||
|
||||
uint64_t seed;
|
||||
|
||||
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
|
||||
// Use cute::Tensor and cute::Layout for iterating thru the matrix elements
|
||||
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
|
||||
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
|
||||
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
|
||||
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
|
||||
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
|
||||
// Output Tensors
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
|
||||
cutlass::HostTensor<ElementSFD, cutlass::layout::PackedVectorLayout> block_SFD;
|
||||
// Reference Output Tensors
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
|
||||
cutlass::HostTensor<ElementSFD, cutlass::layout::PackedVectorLayout> block_reference_SFD;
|
||||
// Matrix-wide normalization constant
|
||||
cutlass::HostTensor<ElementCompute, cutlass::layout::PackedVectorLayout> block_Normconst;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
using namespace cute;
|
||||
if constexpr (cute::is_subbyte_v<T>) {
|
||||
return subbyte_iterator<T>(ptr);
|
||||
}
|
||||
else {
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(1024), n(1024), k(1024),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "72b_blackwell_nvfp4_nvfp4_gemm\n\n"
|
||||
<< " Blackwell NVFP4 GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ " << "./examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_block(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if constexpr (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if constexpr (bits_input <= 6) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if constexpr (bits_input <= 8) {
|
||||
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
|
||||
scope_max = 4;
|
||||
scope_min = 1;
|
||||
}
|
||||
else {
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
}
|
||||
}
|
||||
else{
|
||||
scope_max = 4;
|
||||
scope_min = -4;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
using namespace cute;
|
||||
// For SFA and SFB tensors layouts
|
||||
using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
|
||||
// For SFD tensor layout
|
||||
using Sm100BlockScaledOutputConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
||||
|
||||
layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A);
|
||||
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
|
||||
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
|
||||
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
|
||||
layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFD = SfdOutputCfg::tile_atom_to_shape_SFD(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
|
||||
block_A.reset(cutlass::make_Coord(size(layout_A)));
|
||||
block_B.reset(cutlass::make_Coord(size(layout_B)));
|
||||
block_C.reset(cutlass::make_Coord(size(layout_C)));
|
||||
block_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_reference_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
|
||||
block_Normconst.reset(cutlass::make_Coord(1));
|
||||
|
||||
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
|
||||
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
|
||||
block_SFD.reset(cutlass::make_Coord(size(filter_zeros(layout_SFD))));
|
||||
|
||||
initialize_block(block_A.host_view(), seed + 2021);
|
||||
initialize_block(block_B.host_view(), seed + 2022);
|
||||
initialize_block(block_C.host_view(), seed + 2023);
|
||||
initialize_block(block_SFA.host_view(), seed + 2024);
|
||||
initialize_block(block_SFB.host_view(), seed + 2025);
|
||||
block_Normconst.at(cutlass::make_Coord(0)) = 2;
|
||||
|
||||
block_A.sync_device();
|
||||
block_B.sync_device();
|
||||
block_C.sync_device();
|
||||
block_D.sync_device();
|
||||
block_SFA.sync_device();
|
||||
block_SFB.sync_device();
|
||||
block_SFD.sync_device();
|
||||
block_Normconst.sync_device();
|
||||
|
||||
}
|
||||
|
||||
// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
typename Gemm::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, 1},
|
||||
{ // Mainloop arguments
|
||||
block_A.device_data(), stride_A,
|
||||
block_B.device_data(), stride_B,
|
||||
block_SFA.device_data(), layout_SFA,
|
||||
block_SFB.device_data(), layout_SFB
|
||||
},
|
||||
{ // Epilogue arguments
|
||||
{ options.alpha, options.beta },
|
||||
block_C.device_data(), stride_C,
|
||||
block_D.device_data(), stride_D}
|
||||
};
|
||||
|
||||
if constexpr (IsBlockScaleSupported) {
|
||||
arguments.epilogue.thread.block_scale_factor_ptr = block_SFD.device_data();
|
||||
arguments.epilogue.thread.norm_constant_ptr = block_Normconst.device_data();
|
||||
}
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
using namespace cute;
|
||||
// Create the arguments for host reference implementation
|
||||
Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A);
|
||||
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
|
||||
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
|
||||
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
|
||||
|
||||
// think about how to simplify the gemm3x interface.
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
decltype(tensor_A), // TensorA
|
||||
decltype(tensor_SFA), // TensorSfA
|
||||
decltype(tensor_B), // TensorB
|
||||
decltype(tensor_SFB) // TensorSfB
|
||||
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
|
||||
|
||||
Tensor tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
|
||||
Tensor tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
|
||||
Tensor tensor_SFD = make_tensor(block_reference_SFD.host_data(), layout_SFD);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingEpilogueParams<
|
||||
ElementCompute, // ElementScalar
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
ElementCompute, // ElementCompute
|
||||
decltype(tensor_C), // TensorC
|
||||
decltype(tensor_D), // TensorD
|
||||
decltype(tensor_SFD), // TensorSfD
|
||||
cute::Int<OutputSFVectorSize>,
|
||||
cutlass::reference::host::SfStrategy::SfDGen
|
||||
> epilogue_params {options.alpha, options.beta, tensor_C, tensor_D, tensor_SFD, block_Normconst.at(cutlass::make_Coord(0))};
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// Comparison
|
||||
block_D.sync_host();
|
||||
bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
|
||||
// and must have compute capability at least 100.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,545 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
This example demonstrates a simple way to instantiate and run a mixed precision blockscaled GEMM on the NVIDIA Blackwell SM100 architecture.
|
||||
This Blackwell SM100 CUTLASS kernel uses the new Block Scaled Tensor Core MMA Instructions (tcgen05.mma.blockscaled) introduced
|
||||
on the Blackwell architecture (sm100a) which have the same throughput compared to fp8 Tensor Core MMA instructions (tcgen05.mma)
|
||||
and 2x throughput compared to fp8 Hopper Tensor Core MMA Instructions (WGMMA) (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
Similar to 72a_blackwell_nvfp4_fp32_gemm, this kernel leverages:
|
||||
1. Blockscaled tcgen05.mma instructions.
|
||||
|
||||
2. Per-SM memory called Tensor Memory (TMEM) (Please refer to CUDA 12.8 docs on https://docs.nvidia.com/cuda/).
|
||||
|
||||
3. The extended warp-specialized kernel design introduced in Hopper enabled by use of TMEM
|
||||
which allows us to decouple the execution of MMA and epilogue into separate warps.
|
||||
|
||||
4. A new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution).
|
||||
|
||||
Usage:
|
||||
|
||||
$ ./examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm --m=2048 --n=2048 --k=2048
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>; // Element type for A matrix operand
|
||||
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 16; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::mx_float4_t<cutlass::float_e2m1_t>; // Element type for A matrix operand
|
||||
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
|
||||
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
|
||||
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
|
||||
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
|
||||
|
||||
// Kernel Perf config
|
||||
using MmaTileShape = Shape<_256,_256,_256>; // MMA's tile size
|
||||
using ClusterShape = Shape<_4,_4,_1>; // Shape of the threadblocks in a cluster
|
||||
using PerSmTileShape_MNK = Shape<_128,_256,_256>; // Threadblock-level tile size
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
PerSmTileShape_MNK, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutCTag, AlignmentC,
|
||||
ElementD, LayoutDTag, AlignmentD,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutATag, AlignmentA,
|
||||
ElementB, LayoutBTag, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
|
||||
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
LayoutA layout_A;
|
||||
LayoutSFA layout_SFA;
|
||||
StrideB stride_B;
|
||||
LayoutB layout_B;
|
||||
LayoutSFB layout_SFB;
|
||||
StrideC stride_C;
|
||||
LayoutC layout_C;
|
||||
StrideD stride_D;
|
||||
LayoutD layout_D;
|
||||
uint64_t seed;
|
||||
|
||||
// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device
|
||||
// Use cute::Tensor and cute::Layout for iterating thru the matrix elements
|
||||
cutlass::HostTensor<ElementA::DataType, cutlass::layout::PackedVectorLayout> block_A;
|
||||
cutlass::HostTensor<ElementA::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFA;
|
||||
cutlass::HostTensor<ElementB::DataType, cutlass::layout::PackedVectorLayout> block_B;
|
||||
cutlass::HostTensor<ElementB::ScaleFactorType, cutlass::layout::PackedVectorLayout> block_SFB;
|
||||
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> block_C;
|
||||
// Output Tensor
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_D;
|
||||
// Reference Output Tensor
|
||||
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> block_reference_D;
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
using namespace cute;
|
||||
if constexpr (cute::is_subbyte_v<T>) {
|
||||
return subbyte_iterator<T>(ptr);
|
||||
}
|
||||
else {
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(1024), n(1024), k(1024),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "72c_blackwell_mixed_mxfp8_bf16_gemm\n\n"
|
||||
<< " Blackwell Mxfp8 x Mxfp4 GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ " << "/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_block(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if constexpr (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if constexpr (bits_input <= 6) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if constexpr (bits_input <= 8) {
|
||||
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
|
||||
scope_max = 4;
|
||||
scope_min = 1;
|
||||
}
|
||||
else {
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
}
|
||||
}
|
||||
else{
|
||||
scope_max = 4;
|
||||
scope_min = -4;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
using namespace cute;
|
||||
// For SFA and SFB tensors layouts
|
||||
using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
||||
|
||||
layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A);
|
||||
layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B);
|
||||
layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C);
|
||||
layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D);
|
||||
layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1));
|
||||
|
||||
block_A.reset(cutlass::make_Coord(size(layout_A)));
|
||||
block_B.reset(cutlass::make_Coord(size(layout_B)));
|
||||
block_C.reset(cutlass::make_Coord(size(layout_C)));
|
||||
block_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_reference_D.reset(cutlass::make_Coord(size(layout_D)));
|
||||
block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA))));
|
||||
block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB))));
|
||||
|
||||
initialize_block(block_A.host_view(), seed + 2021);
|
||||
initialize_block(block_B.host_view(), seed + 2022);
|
||||
initialize_block(block_C.host_view(), seed + 2023);
|
||||
initialize_block(block_SFA.host_view(), seed + 2024);
|
||||
initialize_block(block_SFB.host_view(), seed + 2025);
|
||||
|
||||
block_A.sync_device();
|
||||
block_B.sync_device();
|
||||
block_C.sync_device();
|
||||
block_SFA.sync_device();
|
||||
block_SFB.sync_device();
|
||||
}
|
||||
|
||||
// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
typename Gemm::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, 1},
|
||||
{ // Mainloop arguments
|
||||
block_A.device_data(), stride_A,
|
||||
block_B.device_data(), stride_B,
|
||||
block_SFA.device_data(), layout_SFA,
|
||||
block_SFB.device_data(), layout_SFB
|
||||
},
|
||||
{ // Epilogue arguments
|
||||
{options.alpha, options.beta},
|
||||
block_C.device_data(), stride_C,
|
||||
block_D.device_data(), stride_D
|
||||
}
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
using namespace cute;
|
||||
// Create the arguments for host reference implementation
|
||||
Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A);
|
||||
Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA);
|
||||
Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B);
|
||||
Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
decltype(tensor_A), // TensorA
|
||||
decltype(tensor_SFA), // TensorSfA
|
||||
decltype(tensor_B), // TensorB
|
||||
decltype(tensor_SFB) // TensorSfB
|
||||
> mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
|
||||
|
||||
auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C);
|
||||
auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D);
|
||||
|
||||
cutlass::reference::host::GettBlockScalingEpilogueParams<
|
||||
ElementAccumulator, // ElementScalar
|
||||
ElementAccumulator, // ElementAccumulator
|
||||
ElementAccumulator, // ElementCompute
|
||||
decltype(tensor_C), // TensorC
|
||||
decltype(tensor_D) // TensorD
|
||||
> epilogue_params{options.alpha, options.beta, tensor_C, tensor_D};
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
// Comparison
|
||||
block_D.sync_host();
|
||||
bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view());
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0);
|
||||
passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0);
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example
|
||||
// and must have compute capability at least 100.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
46
examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt
Normal file
46
examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt
Normal file
@ -0,0 +1,46 @@
|
||||
|
||||
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
|
||||
cutlass_example_add_executable(
|
||||
72a_blackwell_nvfp4_bf16_gemm
|
||||
72a_blackwell_nvfp4_bf16_gemm.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
72b_blackwell_nvfp4_nvfp4_gemm
|
||||
72b_blackwell_nvfp4_nvfp4_gemm.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
72c_blackwell_mixed_mxfp8_bf16_gemm
|
||||
72c_blackwell_mixed_mxfp8_bf16_gemm.cu
|
||||
)
|
||||
endif()
|
||||
36
examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt
Normal file
36
examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt
Normal file
@ -0,0 +1,36 @@
|
||||
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
|
||||
if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
|
||||
cutlass_example_add_executable(
|
||||
73_blackwell_gemm_preferred_cluster
|
||||
blackwell_gemm_preferred_cluster.cu
|
||||
)
|
||||
endif()
|
||||
@ -0,0 +1,541 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM100 architecture with preferred cluster.
|
||||
|
||||
With the introduction of NVIDIA Compute Capability 9.0, the CUDA programming model introduced
|
||||
an optional hierarchy level known as Thread Block Clusters, which consist of multiple Thread Blocks.
|
||||
While the CUDA programming model has supported the specification of cluster shapes at runtime
|
||||
(Dynamic Clusters) since the Hopper architecture, CUTLASS has only provided support for Static
|
||||
Clusters, meaning that cluster shapes must be defined at compile time.
|
||||
|
||||
Larger cluster shapes can achieve higher TMA multicast but may result in poor SM occupancy due
|
||||
to quantization. For instance, a 2x2 cluster on an 18 SM GPU would only utilize 16 SMs, leaving
|
||||
2 SMs idle.
|
||||
|
||||
Starting with Compute Capability 10.0, the CUDA programming model adds the ability to specify
|
||||
two clusters: preferred cluster and fallback cluster. For brevity, we refer to this as
|
||||
Preferred Clusters. In the previous example, users can now launch an additional 2x1 cluster to
|
||||
utilize the 2 idle SMs.
|
||||
|
||||
With CUTLASS 3.8, in addition to Dynamic Clusters, CUTLASS adds support for Preferred Dynamic Cluster,
|
||||
the ability for users to specify two clusters shapes at runtime.
|
||||
|
||||
Terminology
|
||||
* Static cluster: cluster shape is specified at compile time.
|
||||
* Dynamic cluster: cluster shape is specified at runtime and set by the host.
|
||||
* Preferred cluster: Kernel can be launched with two cluster shapes (preferred and fallback).
|
||||
|
||||
Preferred and fallback cluster shapes are subject to several constraints.
|
||||
* Preferred cluster depth (Z dimension) must be the same as that of fallback cluster.
|
||||
* Fallback cluster shape must evenly divide the preferred cluster shape.
|
||||
* Preferred cluster shape must evenly divide the kernel launch grid shape.
|
||||
|
||||
This example demonstrates how to use the Dynamic Clusters and Preferred Clusters features in
|
||||
CUTLASS 3.x Blackwell SM100 kernels. Users can specify preferred and fallback cluster shapes via GEMM arguments.
|
||||
|
||||
# Example:
|
||||
./73_blackwell_gemm_preferred_cluster" --m=4096 --n=4096 --k=4096 --preferred_cluster_m=4 --preferred_cluster_n=4 --fallback_cluster_m=2 --fallback_cluster_m=1
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = half_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = half_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = float; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
|
||||
// MMA and Cluster Tile Shapes
|
||||
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape % 2 == 0
|
||||
using MmaTileShape_MNK = Shape<_256,_128,_64>;
|
||||
// Shape of the threadblocks participating in a tcgen05 MMA. <1, 1, 1> for cta_group = 1, <2, 1, 1> for cta_group = 2
|
||||
using AtomThrShape_MNK = Shape<_2, _1, _1>;
|
||||
// Shape of the tile computed by each SM
|
||||
using PerSmTileShape_MNK = decltype(shape_div(MmaTileShape_MNK{}, AtomThrShape_MNK{}));
|
||||
// Shape of the cluster set to <int,int,_1> to indicate dynamic cluster shape
|
||||
using ClusterShape_MNK = Shape<int,int,_1>;
|
||||
// When dynamic cluster is used, KernelScheduleAuto always selects mainloop dispatch policy that
|
||||
// lowers to tcgen05 MMA cta_group = 1 as we don't know if the dynamic cluster M dimension will be a multiple of 2
|
||||
// To use KernelScheduleAuto, users need to set AtomThrShape_MNK to Shape<1, 1, 1>
|
||||
using KernelSchedule = cute::conditional_t<cute::size(AtomThrShape_MNK{}) == 2,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized2SmSm100,
|
||||
cutlass::gemm::collective::KernelScheduleAuto>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
PerSmTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
void // <--- Default to cluster launch control (CLC) scheduler
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using DeviceGemmReference = cutlass::reference::device::Gemm<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
int preferred_cluster_m, preferred_cluster_n, fallback_cluster_m, fallback_cluster_n;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(4096), n(4096), k(4096),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10),
|
||||
preferred_cluster_m(4),
|
||||
preferred_cluster_n(4),
|
||||
fallback_cluster_m(2),
|
||||
fallback_cluster_n(1)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("preferred_cluster_m", preferred_cluster_m, 4);
|
||||
cmd.get_cmd_line_argument("preferred_cluster_n", preferred_cluster_n, 4);
|
||||
cmd.get_cmd_line_argument("fallback_cluster_m", fallback_cluster_m, 2);
|
||||
cmd.get_cmd_line_argument("fallback_cluster_n", fallback_cluster_n, 1);
|
||||
|
||||
if (!validate_cluster_shape()){
|
||||
std::cout << "--Invalid cluster shapes" << std::endl;
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "73_blackwell_gemm_preferred_cluster\n\n"
|
||||
<< " Blackwell FP16 GEMM using preferred cluster.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --preferred_cluster_m=<str> Sets the M extent of preferred cluster shape\n"
|
||||
<< " --preferred_cluster_n=<str> Sets the N extent of preferred cluster shape\n"
|
||||
<< " --fallback_cluster_m=<str> Sets the M extent of fallback cluster shape\n"
|
||||
<< " --fallback_cluster_n=<str> Sets the N extent of fallback cluster shape\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "Preferred cluster shape cannot be smaller than fallback cluster shape.\n"
|
||||
<< "Preferred cluster shape must be a multiple of fallback cluster shape.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ " << "73_blackwell_gemm_preferred_cluster" << " --m=4096 --n=4096 --k=4096 --preferred_cluster_m=4 --preferred_cluster_n=4 --fallback_cluster_m=2 --fallback_cluster_m=1\n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const {
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
|
||||
private:
|
||||
/// Validate preferred and fallback cluster shapes
|
||||
bool validate_cluster_shape() {
|
||||
if (preferred_cluster_m < fallback_cluster_m || preferred_cluster_n < fallback_cluster_n) {
|
||||
std::cout << "--Preferred cluster cannot be smaller than fallback cluster" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (preferred_cluster_m % fallback_cluster_m != 0 || preferred_cluster_n % fallback_cluster_n != 0) {
|
||||
std::cout << "--Preferred cluster must be a multiple of fallback cluster" << std::endl;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(cutlass::DeviceAllocation<Element>& block, uint64_t seed=2023) {
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(0);
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(-2);
|
||||
} else {
|
||||
scope_max = Element(8);
|
||||
scope_min = Element(-8);
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
||||
|
||||
block_A.reset(options.m * options.k);
|
||||
block_B.reset(options.k * options.n);
|
||||
block_C.reset(options.m * options.n);
|
||||
block_D.reset(options.m * options.n);
|
||||
block_ref_D.reset(options.m * options.n);
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options) {
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, 1},
|
||||
{block_A.get(), stride_A, block_B.get(), stride_B},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
|
||||
arguments.hw_info.cluster_shape = dim3(options.preferred_cluster_m, options.preferred_cluster_n,1);
|
||||
arguments.hw_info.cluster_shape_fallback = dim3(options.fallback_cluster_m, options.fallback_cluster_n,1);
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k}));
|
||||
cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.k, options.n}));
|
||||
cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n}));
|
||||
cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n}));
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
DeviceGemmReference gemm_reference;
|
||||
|
||||
// Launch device reference gemm kernel
|
||||
gemm_reference(
|
||||
{options.m, options.n, options.k},
|
||||
ElementAccumulator(options.alpha),
|
||||
ref_A,
|
||||
ref_B,
|
||||
ElementAccumulator(options.beta),
|
||||
ref_C,
|
||||
ref_D);
|
||||
|
||||
// Wait for kernel to finish
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
int run(Options &options) {
|
||||
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << "GEMM with"
|
||||
<< " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k
|
||||
<< " Preferred Cluster = (" << options.preferred_cluster_m << ", " << options.preferred_cluster_n << ", 1)"
|
||||
<< " Fallback Cluster = (" << options.fallback_cluster_m << ", " << options.fallback_cluster_n << ", 1)"
|
||||
<< std::endl;
|
||||
|
||||
std::cout << "--------------------------------------------------------------------------------" << std::endl;
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
|
||||
// and must have compute capability at least 100.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (props.major != 10 || props.minor != 0) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
37
examples/74_blackwell_gemm_streamk/CMakeLists.txt
Normal file
37
examples/74_blackwell_gemm_streamk/CMakeLists.txt
Normal file
@ -0,0 +1,37 @@
|
||||
|
||||
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
|
||||
if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
|
||||
cutlass_example_add_executable(
|
||||
74_blackwell_gemm_streamk
|
||||
blackwell_gemm_streamk.cu
|
||||
)
|
||||
endif()
|
||||
592
examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu
Normal file
592
examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu
Normal file
@ -0,0 +1,592 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM100 architecture with the Stream-K scheduler.
|
||||
|
||||
Stream-K is a GEMM parallelization technique that attempts to reduce load imbalance across SMs
|
||||
by parallelizing certain output tiles across the K mode of the GEMM, without using a static splitting factor.
|
||||
For complete details on Stream-K, please see https://arxiv.org/abs/2301.03598.
|
||||
|
||||
CUTLASS's Stream-K scheduler using the CUTLASS 3.x API is capable of supporting various modes of
|
||||
decomposing a GEMM (referred to as "decomposition modes" in this example):
|
||||
* DataParallel: basic GEMM parallelized spatially via tiling, but without splitting the K mode
|
||||
* SplitK: `split_factor` CTAs compute portions of the K mode for a given output tile and reduce their results
|
||||
* StreamK: parallelizes work according to the stream-K load balancing method described in https://arxiv.org/abs/2301.03598
|
||||
* Heuristic: applies an internal heuristic in attempt to choose the most performant among the three preceding decomposition modes
|
||||
|
||||
Additionally, the Stream-K scheduler supports two different means of performing reductions for
|
||||
decomposition modes that require reduction (SplitK, StreamK, and Heuristic):
|
||||
* Deterministic: Participating CTAs perform reduction in a turnstile fashion in order of the K mode
|
||||
covered by each CTA. This requires a lock to be held exclusively by the CTA that is
|
||||
currently accumulating.
|
||||
* Nondeterministic: Participating CTAs perform reduction atomically to the same workspace (mostly) without locking.
|
||||
Locks are used only to wait for the first CTA to write its partial values (to initialize the
|
||||
workspace), and for all but the final CTA to have accumulated (so that the final CTA can load
|
||||
the accumulated value and accumulate it into registers on top of which the epilogue will
|
||||
be performed). Due to the nondeterminsitic ordering of accumulation, deterministic numeric
|
||||
behavior cannot be guaranteed with this mode (e.g., floating-point rounding error will depend
|
||||
on the order of accumulation)
|
||||
|
||||
This example allows one to try out different decomposition modes, reduction modes, and (when using Split-K) splitting factors.
|
||||
Here are a few examples of usage:
|
||||
# Heuristic mode with deterministic reduction
|
||||
./74_blackwell_gemm_streamk" --m=256 --n=256 --k=16384 --decomposition=Heuristic --reduction=Deterministic
|
||||
|
||||
# Stream-K mode with determinsitic reduction
|
||||
./74_blackwell_gemm_streamk" --m=256 --n=256 --k=16384 --decomposition=StreamK --reduction=Deterministic
|
||||
|
||||
# Split-K mode with a splitting factor of 2 and deterministic reduction
|
||||
./74_blackwell_gemm_streamk" --m=256 --n=256 --k=16384 --decomposition=SplitK --reduction=Deterministic --splits=2
|
||||
|
||||
# Stream-K mode with nondeterministic reduction
|
||||
./74_blackwell_gemm_streamk" --m=256 --n=256 --k=16384 --decomposition=StreamK --reduction=Nondeterministic
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = half_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = half_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = float; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
|
||||
// MMA and Cluster Tile Shapes
|
||||
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape % 2 == 0
|
||||
using MmaTileShape_MNK = Shape<_256,_128,_64>;
|
||||
// Shape of the threadblocks participating in a tcgen05 MMA. <1, 1, 1> for cta_group = 1, <2, 1, 1> for cta_group = 2
|
||||
using AtomThrShape_MNK = Shape<_2, _1, _1>;
|
||||
// Shape of the tile computed by each SM
|
||||
using PerSmTileShape_MNK = decltype(shape_div(MmaTileShape_MNK{}, AtomThrShape_MNK{}));
|
||||
// Shape of the cluster set to <int,int,_1> to indicate dynamic cluster shape
|
||||
using ClusterShape_MNK = Shape<int,int,_1>;
|
||||
// When dynamic cluster is used, KernelScheduleAuto always selects mainloop dispatch policy that
|
||||
// lowers to tcgen05 MMA cta_group = 1 as we don't know if the dynamic cluster M dimension will be a multiple of 2
|
||||
// To use KernelScheduleAuto, users need to set AtomThrShape_MNK to Shape<1, 1, 1>
|
||||
using KernelSchedule = cute::conditional_t<cute::size(AtomThrShape_MNK{}) == 2,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized2SmSm100,
|
||||
cutlass::gemm::collective::KernelScheduleAuto>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
PerSmTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule
|
||||
>::CollectiveOp;
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
cutlass::gemm::StreamKScheduler // <--- Change needed to enable the stream-K scheduler
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using DeviceGemmReference = cutlass::reference::device::Gemm<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideA stride_A;
|
||||
StrideB stride_B;
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int m, n, k;
|
||||
int preferred_cluster_m, preferred_cluster_n, fallback_cluster_m, fallback_cluster_n;
|
||||
using DecompositionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
|
||||
using ReductionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::ReductionMode;
|
||||
DecompositionMode decomposition_mode;
|
||||
ReductionMode reduction_mode;
|
||||
int splits;
|
||||
|
||||
std::unordered_map<DecompositionMode, std::vector<std::string>> dec_mappings = {
|
||||
{DecompositionMode::Heuristic, {"Heuristic", "heuristic", "h", "H", ""}},
|
||||
{DecompositionMode::SplitK, {"SplitK", "split-k", "split-K", "Split-K", "Split-k", "splitk", "Splitk", "splitK", "spk", "SpK", "spK"}},
|
||||
{DecompositionMode::StreamK, {"StreamK", "stream-k", "stream-K", "Stream-K", "Stream-k", "streamk", "Streamk", "streamK", "stk", "StK", "stK"}},
|
||||
{DecompositionMode::DataParallel, {"DataParallel", "data-parallel", "dataparallel", "dp", "DP"}}
|
||||
};
|
||||
|
||||
std::unordered_map<ReductionMode, std::vector<std::string>> red_mappings = {
|
||||
{ReductionMode::Deterministic, {"Deterministic", "deterministic", "d", "D", ""}},
|
||||
{ReductionMode::Nondeterministic, {"Nondeterministic", "nondeterministic", "n", "N"}}
|
||||
};
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
m(256), n(256), k(16384),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10),
|
||||
preferred_cluster_m(4),
|
||||
preferred_cluster_n(4),
|
||||
fallback_cluster_m(2),
|
||||
fallback_cluster_n(1),
|
||||
decomposition_mode(DecompositionMode::Heuristic),
|
||||
reduction_mode(ReductionMode::Deterministic),
|
||||
splits(1)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("splits", splits, 1);
|
||||
cmd.get_cmd_line_argument("preferred_cluster_m", preferred_cluster_m, 4);
|
||||
cmd.get_cmd_line_argument("preferred_cluster_n", preferred_cluster_n, 4);
|
||||
cmd.get_cmd_line_argument("fallback_cluster_m", fallback_cluster_m, 2);
|
||||
cmd.get_cmd_line_argument("fallback_cluster_n", fallback_cluster_n, 1);
|
||||
|
||||
// Parse decompsition mode
|
||||
std::string decomp_mode;
|
||||
cmd.get_cmd_line_argument("decomposition", decomp_mode);
|
||||
bool found = parse_from_options_map(decomp_mode, dec_mappings, decomposition_mode);
|
||||
if (!found) {
|
||||
std::cout << "--decomposition must be one of Heuristic, SplitK, StreamK, or DataParallel" << std::endl;
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
// Parse reduction mode
|
||||
std::string red_mode;
|
||||
cmd.get_cmd_line_argument("reduction", red_mode);
|
||||
found = parse_from_options_map(red_mode, red_mappings, reduction_mode);
|
||||
if (!found) {
|
||||
std::cout << "--reduction must be one of Deterministic and Nondeterministic" << std::endl;
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "74_blackwell_gemm_streamk\n\n"
|
||||
<< " Blackwell FP16 GEMM using a stream-K kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --preferred_cluster_m=<str> Sets the M extent of preferred cluster shape\n"
|
||||
<< " --preferred_cluster_n=<str> Sets the N extent of preferred cluster shape\n"
|
||||
<< " --fallback_cluster_m=<str> Sets the M extent of fallback cluster shape\n"
|
||||
<< " --fallback_cluster_n=<str> Sets the N extent of fallback cluster shape\n"
|
||||
<< " --decomposition=<str> Mode in which the stream-K kernel should decompose the problem. Options: Heuristic (default), SplitK, StreamK, DataParallel\n"
|
||||
<< " --reduction=<str> Mode in which the stream-K kernel's reduction should be performed. Options: Deterministic (default), Nondeterministic\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "74_blackwell_gemm_streamk" << " --m=256 --n=256 --k=16384 --decomposition=Heuristic --reduction=Deterministic \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const {
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * m * n * k;
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
|
||||
std::string decomposition_mode_str() const {
|
||||
return dec_mappings.at(decomposition_mode).at(0);
|
||||
}
|
||||
|
||||
std::string reduction_mode_str() const {
|
||||
return red_mappings.at(reduction_mode).at(0);
|
||||
}
|
||||
|
||||
private:
|
||||
template <class T>
|
||||
bool parse_from_options_map(std::string val, std::unordered_map<T, std::vector<std::string>> options, T& result) const {
|
||||
for (const auto & [key, values] : options) {
|
||||
if (std::find(values.begin(), values.end(), val) != values.end()) {
|
||||
result = key;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(cutlass::DeviceAllocation<Element>& block, uint64_t seed=2023) {
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(0);
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(-2);
|
||||
} else {
|
||||
scope_max = Element(8);
|
||||
scope_min = Element(-8);
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
|
||||
stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
|
||||
stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1});
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1});
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1});
|
||||
|
||||
block_A.reset(options.m * options.k);
|
||||
block_B.reset(options.k * options.n);
|
||||
block_C.reset(options.m * options.n);
|
||||
block_D.reset(options.m * options.n);
|
||||
block_ref_D.reset(options.m * options.n);
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options) {
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{options.m, options.n, options.k, 1},
|
||||
{block_A.get(), stride_A, block_B.get(), stride_B},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
|
||||
arguments.hw_info.cluster_shape = dim3(options.preferred_cluster_m, options.preferred_cluster_n,1);
|
||||
arguments.hw_info.cluster_shape_fallback = dim3(options.fallback_cluster_m, options.fallback_cluster_n,1);
|
||||
|
||||
arguments.scheduler.splits = options.splits;
|
||||
arguments.scheduler.decomposition_mode = options.decomposition_mode;
|
||||
arguments.scheduler.reduction_mode = options.reduction_mode;
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k}));
|
||||
cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.k, options.n}));
|
||||
cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n}));
|
||||
cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n}));
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
DeviceGemmReference gemm_reference;
|
||||
|
||||
// Launch device reference gemm kernel
|
||||
gemm_reference(
|
||||
{options.m, options.n, options.k},
|
||||
ElementAccumulator(options.alpha),
|
||||
ref_A,
|
||||
ref_B,
|
||||
ElementAccumulator(options.beta),
|
||||
ref_C,
|
||||
ref_D);
|
||||
|
||||
// Wait for kernel to finish
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
int run(Options &options) {
|
||||
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << "Stream-K GEMM with"
|
||||
<< " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k
|
||||
<< " Preferred Cluster = (" << options.preferred_cluster_m << ", " << options.preferred_cluster_n << ", 1)"
|
||||
<< " Fallback Cluster = (" << options.fallback_cluster_m << ", " << options.fallback_cluster_n << ", 1)\n"
|
||||
<< " Decomposition_mode=" << options.decomposition_mode_str()
|
||||
<< " Split_count=" << options.splits
|
||||
<< " Reduction_mode=" << options.reduction_mode_str()
|
||||
<< std::endl;
|
||||
|
||||
std::cout << "--------------------------------------------------------------------------------" << std::endl;
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
|
||||
// and must have compute capability at least 100.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
813
examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu
Normal file
813
examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu
Normal file
@ -0,0 +1,813 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
|
||||
/*! \file
|
||||
\brief Grouped GEMM example using CUTLASS 3 APIs for the NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
This example demonstrates an implementation of Grouped GEMM using a TMA + Blackwell SM100 TensorOp-based warp-specialized kernel.
|
||||
For this example all scheduling work is performed on the device.
|
||||
The new feature showcased in this example is device-side modification of TMA descriptors
|
||||
to move between groups/problem_count (represented by groups).
|
||||
https://docs.nvidia.com/cuda/cuda-c-programming-guide/#encoding-a-tensor-map-on-device
|
||||
|
||||
To run this example:
|
||||
|
||||
$ ./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm --m=2048 --n=2048 --k=2048 --groups=10
|
||||
|
||||
The above example command makes all 10 groups to be sized at the given m, n, k sizes.
|
||||
Skipping any of the problem dimensions randomizes it across the different groups.
|
||||
Same applies for alpha and beta values that are randomized across the different groups.
|
||||
|
||||
To run this example for a set of problems using the benchmark option:
|
||||
|
||||
$ ./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm --benchmark=./test_benchmark.txt
|
||||
|
||||
Where the test_benchmark.txt may look as such:
|
||||
0 256x512x128
|
||||
1 256x512x512
|
||||
2 512x256x128
|
||||
3 256x256x128
|
||||
4 256x512x1024
|
||||
5 1024x512x128 and so on
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <float.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "helper.h"
|
||||
using namespace cute;
|
||||
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
|
||||
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// A matrix configuration
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
|
||||
// Runtime Cluster Shape
|
||||
using ClusterShape = Shape<int32_t,int32_t,_1>;
|
||||
// For Static Cluster Shape:
|
||||
// using ClusterShape = Shape<_2,_1,_1>; // for example
|
||||
// using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); // for 2SM config
|
||||
// using OutputTileShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); // for epilogue builder
|
||||
// using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); // for mainloop builder
|
||||
|
||||
// Different configs for 1SM and 2SM MMA kernel
|
||||
struct MMA1SMConfig {
|
||||
using MmaTileShape = Shape<_128,_256,Int<128 / sizeof(ElementA)>>;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch
|
||||
using OutputTileShape = decltype(shape_div(MmaTileShape{}, Shape<_1,_1,_1>{}));
|
||||
};
|
||||
|
||||
struct MMA2SMConfig {
|
||||
using MmaTileShape = Shape<_256,_256,Int<128 / sizeof(ElementA)>>;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch
|
||||
using OutputTileShape = decltype(shape_div(MmaTileShape{}, Shape<_2,_1,_1>{}));
|
||||
};
|
||||
|
||||
template <typename ScheduleConfig>
|
||||
struct GivenGemmSchedule {
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
typename ScheduleConfig::OutputTileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
typename ScheduleConfig::EpilogueSchedule,
|
||||
cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA *, AlignmentA,
|
||||
ElementB, LayoutB *, AlignmentB,
|
||||
ElementAccumulator,
|
||||
typename ScheduleConfig::MmaTileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
typename ScheduleConfig::KernelSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
};
|
||||
|
||||
using GemmKernel1SM = GivenGemmSchedule<MMA1SMConfig>::GemmKernel;
|
||||
using Gemm1SM = GivenGemmSchedule<MMA1SMConfig>::Gemm;
|
||||
using Gemm = Gemm1SM;
|
||||
|
||||
using GemmKernel2SM = GivenGemmSchedule<MMA2SMConfig>::GemmKernel;
|
||||
using Gemm2SM = GivenGemmSchedule<MMA2SMConfig>::Gemm;
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using DeviceGemmReference = cutlass::reference::device::Gemm<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||
|
||||
// Host-side allocations
|
||||
std::vector<int64_t> offset_A;
|
||||
std::vector<int64_t> offset_B;
|
||||
std::vector<int64_t> offset_C;
|
||||
std::vector<int64_t> offset_D;
|
||||
|
||||
std::vector<StrideA> stride_A_host;
|
||||
std::vector<StrideB> stride_B_host;
|
||||
std::vector<StrideC> stride_C_host;
|
||||
std::vector<StrideD> stride_D_host;
|
||||
|
||||
std::vector<ElementAccumulator> alpha_host;
|
||||
std::vector<ElementAccumulator> beta_host;
|
||||
|
||||
// Device-side allocations
|
||||
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
|
||||
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
|
||||
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
|
||||
|
||||
cutlass::DeviceAllocation<const typename Gemm::ElementA *> ptr_A;
|
||||
cutlass::DeviceAllocation<const typename Gemm::ElementB *> ptr_B;
|
||||
cutlass::DeviceAllocation<const typename Gemm::ElementC *> ptr_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_ref_D;
|
||||
|
||||
cutlass::DeviceAllocation<StrideA> stride_A;
|
||||
cutlass::DeviceAllocation<StrideB> stride_B;
|
||||
cutlass::DeviceAllocation<StrideC> stride_C;
|
||||
cutlass::DeviceAllocation<StrideD> stride_D;
|
||||
|
||||
// Note, this is an array of pointers to alpha and beta scaling values per group
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
|
||||
float alpha = FLT_MAX;
|
||||
float beta = FLT_MAX;
|
||||
int iterations = 10;
|
||||
int m = 1024, n = 2048, k = 512, groups = 10;
|
||||
dim3 cluster_shape = dim3(4,2,1);
|
||||
dim3 cluster_shape_fallback = dim3(2,1,1);
|
||||
RasterOrderOptions raster_order = RasterOrderOptions::AlongM;
|
||||
int max_sm_count = INT_MAX;
|
||||
std::string benchmark_path;
|
||||
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host;
|
||||
int const tma_alignment_bits = 128;
|
||||
int const alignment = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("groups", groups);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, FLT_MAX);
|
||||
cmd.get_cmd_line_argument("beta", beta, FLT_MAX);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("benchmark", benchmark_path);
|
||||
cmd.get_cmd_line_argument("cluster_m", cluster_shape.x);
|
||||
cmd.get_cmd_line_argument("cluster_n", cluster_shape.y);
|
||||
cmd.get_cmd_line_argument("cluster_fallback_m", cluster_shape_fallback.x);
|
||||
cmd.get_cmd_line_argument("cluster_fallback_n", cluster_shape_fallback.y);
|
||||
cmd.get_cmd_line_argument("max_sm_count", max_sm_count, INT_MAX);
|
||||
|
||||
// Decide how to initialize the problems
|
||||
if (!benchmark_path.empty()) {
|
||||
if (!benchmark_problems()) {
|
||||
problem_sizes_host.clear();
|
||||
return;
|
||||
}
|
||||
}
|
||||
else {
|
||||
randomize_problems(cmd);
|
||||
}
|
||||
|
||||
char raster_char;
|
||||
cmd.get_cmd_line_argument("raster", raster_char);
|
||||
|
||||
if (raster_char == 'N' || raster_char == 'n') {
|
||||
raster_order = RasterOrderOptions::AlongN;
|
||||
}
|
||||
else if (raster_char == 'M' || raster_char == 'm') {
|
||||
raster_order = RasterOrderOptions::AlongM;
|
||||
}
|
||||
}
|
||||
|
||||
void randomize_problems(cutlass::CommandLine &cmd) {
|
||||
int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1;
|
||||
cmd.get_cmd_line_argument("m", cmd_line_m);
|
||||
cmd.get_cmd_line_argument("n", cmd_line_n);
|
||||
cmd.get_cmd_line_argument("k", cmd_line_k);
|
||||
|
||||
problem_sizes_host.reserve(groups);
|
||||
|
||||
for (int i = groups; i > 0; i--) {
|
||||
int m = cmd_line_m;
|
||||
int n = cmd_line_n;
|
||||
int k = cmd_line_k;
|
||||
if (m < 1) {
|
||||
m = alignment * ((rand() % 64) + 1);
|
||||
}
|
||||
if (n < 1) {
|
||||
n = alignment * ((rand() % 64) + 1);
|
||||
}
|
||||
if (k < 1) {
|
||||
k = alignment * ((rand() % 64) + 1);
|
||||
}
|
||||
problem_sizes_host.push_back({m, n, k});
|
||||
}
|
||||
}
|
||||
|
||||
/// Load a benchmark
|
||||
bool benchmark_problems() {
|
||||
std::ifstream file(benchmark_path);
|
||||
if (!file.good()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
while (file.good()) {
|
||||
|
||||
int idx = -1;
|
||||
std::string extent_str;
|
||||
|
||||
file >> idx >> extent_str;
|
||||
|
||||
if (idx < 0 || extent_str.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
cutlass::gemm::GemmCoord extent;
|
||||
std::vector<std::string> tokens;
|
||||
|
||||
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
|
||||
|
||||
for (int i = 0; i < int(tokens.size()); ++i) {
|
||||
int x = std::atoi(tokens.at(i).c_str());
|
||||
|
||||
// round up
|
||||
if (x % alignment) {
|
||||
x += (alignment - (x % alignment));
|
||||
}
|
||||
|
||||
extent.at(i) = x;
|
||||
}
|
||||
|
||||
if (extent.product()) {
|
||||
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
|
||||
}
|
||||
}
|
||||
groups = static_cast<int>(problem_sizes_host.size());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "75_blackwell_grouped_gemm\n\n"
|
||||
<< " Blackwell FP8 Grouped GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM for all groups\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM for all groups\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM for all groups\n"
|
||||
<< " --groups=<int> Sets the number of individual GEMM problems for Grouped GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --cluster_m=<int> and --cluster_n=<int> Sets the X,Y dims of the preferred cluster shape\n"
|
||||
<< " --cluster_fallback_m=<int> and --cluster_fallback_n=<int> Sets the X,Y dims of the fallback cluster shape\n\n"
|
||||
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M)\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform\n\n"
|
||||
<< " --benchmark=<str> Executes a benchmark problem size\n"
|
||||
<< " --max_sm_count=<int> Run kernels using only these number of SMs\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "75_blackwell_grouped_gemm" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s, std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host) const
|
||||
{
|
||||
// Number of real-valued multiply-adds
|
||||
uint64_t fmas = uint64_t();
|
||||
|
||||
for (auto const & problem : problem_sizes_host) {
|
||||
fmas += static_cast<uint64_t>(get<0>(problem)) *
|
||||
static_cast<uint64_t>(get<1>(problem)) *
|
||||
static_cast<uint64_t>(get<2>(problem));
|
||||
}
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * uint64_t(fmas);
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms = 0.0;
|
||||
double gflops = 0.0;
|
||||
cutlass::Status status = cutlass::Status::kSuccess;
|
||||
cudaError_t error = cudaSuccess;
|
||||
bool passed = false;
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023) {
|
||||
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = static_cast<Element>(2);
|
||||
scope_min = static_cast<Element>(0);
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = static_cast<Element>(2);
|
||||
scope_min = static_cast<Element>(-2);
|
||||
} else {
|
||||
scope_max = static_cast<Element>(8);
|
||||
scope_min = static_cast<Element>(-8);
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Allocates device-side data
|
||||
void allocate(const Options &options) {
|
||||
int64_t total_elements_A = 0;
|
||||
int64_t total_elements_B = 0;
|
||||
int64_t total_elements_C = 0;
|
||||
int64_t total_elements_D = 0;
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
auto M = get<0>(problem);
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
|
||||
offset_A.push_back(total_elements_A);
|
||||
offset_B.push_back(total_elements_B);
|
||||
offset_C.push_back(total_elements_C);
|
||||
offset_D.push_back(total_elements_D);
|
||||
|
||||
int64_t elements_A = M * K;
|
||||
int64_t elements_B = K * N;
|
||||
int64_t elements_C = M * N;
|
||||
int64_t elements_D = M * N;
|
||||
|
||||
total_elements_A += elements_A;
|
||||
total_elements_B += elements_B;
|
||||
total_elements_C += elements_C;
|
||||
total_elements_D += elements_D;
|
||||
|
||||
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
|
||||
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
|
||||
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}));
|
||||
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}));
|
||||
|
||||
}
|
||||
|
||||
block_A.reset(total_elements_A);
|
||||
block_B.reset(total_elements_B);
|
||||
block_C.reset(total_elements_C);
|
||||
block_D.reset(total_elements_D);
|
||||
block_ref_D.reset(total_elements_D);
|
||||
block_alpha.reset(options.groups);
|
||||
block_beta.reset(options.groups);
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
|
||||
uint64_t seed = 2020;
|
||||
|
||||
problem_sizes.reset(options.groups);
|
||||
problem_sizes.copy_from_host(options.problem_sizes_host.data());
|
||||
|
||||
//
|
||||
// Assign pointers
|
||||
//
|
||||
|
||||
std::vector<ElementA *> ptr_A_host(options.groups);
|
||||
std::vector<ElementB *> ptr_B_host(options.groups);
|
||||
std::vector<ElementC *> ptr_C_host(options.groups);
|
||||
std::vector<ElementC *> ptr_D_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
|
||||
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
|
||||
ptr_alpha_host.at(i) = block_alpha.get() + i;
|
||||
ptr_beta_host.at(i) = block_beta.get() + i;
|
||||
}
|
||||
|
||||
ptr_A.reset(options.groups);
|
||||
ptr_A.copy_from_host(ptr_A_host.data());
|
||||
|
||||
ptr_B.reset(options.groups);
|
||||
ptr_B.copy_from_host(ptr_B_host.data());
|
||||
|
||||
ptr_C.reset(options.groups);
|
||||
ptr_C.copy_from_host(ptr_C_host.data());
|
||||
|
||||
ptr_D.reset(options.groups);
|
||||
ptr_D.copy_from_host(ptr_D_host.data());
|
||||
|
||||
stride_A.reset(options.groups);
|
||||
stride_A.copy_from_host(stride_A_host.data());
|
||||
|
||||
stride_B.reset(options.groups);
|
||||
stride_B.copy_from_host(stride_B_host.data());
|
||||
|
||||
stride_C.reset(options.groups);
|
||||
stride_C.copy_from_host(stride_C_host.data());
|
||||
|
||||
stride_D.reset(options.groups);
|
||||
stride_D.copy_from_host(stride_D_host.data());
|
||||
|
||||
alpha_device.reset(options.groups);
|
||||
alpha_device.copy_from_host(ptr_alpha_host.data());
|
||||
beta_device.reset(options.groups);
|
||||
beta_device.copy_from_host(ptr_beta_host.data());
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
block_alpha.copy_from_host(alpha_host.data());
|
||||
block_beta.copy_from_host(beta_host.data());
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
template <typename Gemm>
|
||||
typename Gemm::Arguments args_from_options(Options &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = min(cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id), options.max_sm_count);
|
||||
|
||||
if (!is_static_v<ClusterShape>) {
|
||||
if (size<0>(typename Gemm::GemmKernel::CollectiveMainloop::AtomThrShapeMNK{}) == 2 &&
|
||||
(options.cluster_shape.x < 2 || options.cluster_shape_fallback.x < 2)) {
|
||||
std::cout << "Error: MMA2SMConfig kernel config needs cluster_dim.x >= 2" << std::endl;
|
||||
}
|
||||
hw_info.cluster_shape = options.cluster_shape;
|
||||
hw_info.cluster_shape_fallback = options.cluster_shape_fallback;
|
||||
}
|
||||
|
||||
typename Gemm::Arguments arguments;
|
||||
decltype(arguments.epilogue.thread) fusion_args;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
|
||||
// If alpha/beta are provided (via cmd line args) and are scalar, then same alpha/beta applies to all batches.
|
||||
// If pointers to alpha/beta are provided, then alpha/beta can differ between batches/groups.
|
||||
if (options.alpha != FLT_MAX){
|
||||
// Single alpha for all groups
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.alpha_ptr_array = nullptr;
|
||||
fusion_args.dAlpha = {_0{}, _0{}, 0};
|
||||
}
|
||||
else {
|
||||
fusion_args.alpha = 0;
|
||||
fusion_args.alpha_ptr_array = alpha_device.get();
|
||||
// Only one alpha per each group
|
||||
fusion_args.dAlpha = {_0{}, _0{}, 1};
|
||||
}
|
||||
if (options.beta != FLT_MAX) {
|
||||
// Single beta for all groups
|
||||
fusion_args.beta = options.beta;
|
||||
fusion_args.beta_ptr_array = nullptr;
|
||||
fusion_args.dBeta = {_0{}, _0{}, 0};
|
||||
}
|
||||
else {
|
||||
fusion_args.beta = 0;
|
||||
fusion_args.beta_ptr_array = beta_device.get();
|
||||
// Only one beta per each group
|
||||
fusion_args.dBeta = {_0{}, _0{}, 1};
|
||||
}
|
||||
|
||||
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
|
||||
scheduler.raster_order = options.raster_order;
|
||||
|
||||
if (host_problem_shapes_available) {
|
||||
arguments = typename Gemm::Arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
|
||||
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info, scheduler
|
||||
};
|
||||
}
|
||||
else {
|
||||
arguments = typename Gemm::Arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), nullptr},
|
||||
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info, scheduler
|
||||
};
|
||||
}
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
bool passed = true;
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
auto M = get<0>(problem);
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
cutlass::TensorRef ref_A(block_A.get() + offset_A.at(i), Gemm::LayoutA::packed({M, K}));
|
||||
cutlass::TensorRef ref_B(block_B.get() + offset_B.at(i), Gemm::LayoutB::packed({K, N}));
|
||||
cutlass::TensorRef ref_C(block_C.get() + offset_C.at(i), Gemm::LayoutC::packed({M, N}));
|
||||
cutlass::TensorRef ref_D(block_ref_D.get() + offset_D.at(i), Gemm::LayoutD::packed({M, N}));
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
DeviceGemmReference gemm_reference;
|
||||
|
||||
// Launch device reference gemm kernel
|
||||
gemm_reference(
|
||||
{M, N, K},
|
||||
ElementAccumulator(alpha_host.at(i)),
|
||||
ref_A,
|
||||
ref_B,
|
||||
ElementAccumulator(beta_host.at(i)),
|
||||
ref_C,
|
||||
ref_D);
|
||||
|
||||
// Wait for kernel to finish
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
passed &= cutlass::reference::device::BlockCompareEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N);
|
||||
}
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
std::cout << " " << options.problem_sizes_host.at(i);
|
||||
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
|
||||
}
|
||||
std::cout << " Groups : " << options.groups << std::endl;
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options<Gemm>(options, host_problem_shapes_available);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average setup and runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host);
|
||||
|
||||
std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS : " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
|
||||
if (__CUDACC_VER_MAJOR__ < 12 ||
|
||||
((__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)
|
||||
)
|
||||
) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 100a).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
std::cout << "Running kernel with 1SM MMA config:" << std::endl;
|
||||
run<Gemm1SM>(options, false /*host_problem_shapes_available*/);
|
||||
std::cout << "Running kernel with 2SM MMA config:" << std::endl;
|
||||
run<Gemm2SM>(options, false /*host_problem_shapes_available*/);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,953 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
|
||||
/*! \file
|
||||
\brief Grouped GEMM example using CUTLASS 3 APIs for the NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
This example demonstrates an implementation of Grouped GEMM using a TMA + Blackwell SM100 TensorOp-based warp-specialized kernel
|
||||
for narrow precisions (FP4) with Scale Factors (In and Out).
|
||||
For this example all scheduling work is performed on the device.
|
||||
The new feature showcased in this example is device-side modification of TMA descriptors
|
||||
to move between groups/problem_count (represented by groups).
|
||||
https://docs.nvidia.com/cuda/cuda-c-programming-guide/#encoding-a-tensor-map-on-device
|
||||
|
||||
To run this example:
|
||||
|
||||
$ ./examples/75_blackwell_grouped_gemm_block_scaled/75_blackwell_grouped_gemm_block_scaled --m=2048 --n=2048 --k=2048 --groups=10
|
||||
|
||||
The above example command makes all 10 groups to be sized at the given m, n, k sizes.
|
||||
Skipping any of the problem dimensions randomizes it across the different groups.
|
||||
Same applies for alpha and beta values that are randomized across the different groups.
|
||||
|
||||
To run this example for a set of problems using the benchmark option:
|
||||
|
||||
$ ./examples/75_blackwell_grouped_gemm_block_scaled/75_blackwell_grouped_gemm_block_scaled --benchmark=./test_benchmark.txt
|
||||
|
||||
Where the test_benchmark.txt may look as such:
|
||||
0 256x512x128
|
||||
1 256x512x512
|
||||
2 512x256x128
|
||||
3 256x256x128
|
||||
4 256x512x1024
|
||||
5 1024x512x128 and so on
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <float.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
|
||||
#include "helper.h"
|
||||
using namespace cute;
|
||||
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
|
||||
using ElementInput = cutlass::float_e2m1_t; // Element type for Input matrix operands
|
||||
using ElementSF = cutlass::float_ue4m3_t; // Element type for SF matrix operands
|
||||
using ElementC = cutlass::half_t; // Element type for C matrix operands
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::nv_float4_t<ElementInput>; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 32; // Alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::nv_float4_t<ElementInput>; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 32; // Alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementD = ElementC; // Element type for D matrix operands
|
||||
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Alignment of C matrix in units of elements (up to 16 bytes)
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Alignment of D matrix in units of elements (up to 16 bytes)
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
|
||||
// using ElementD = cutlass::float_e2m1_t; // Enable for SF Output // Element type for D matrix operands
|
||||
constexpr int OutputSFVectorSize = 16;
|
||||
using FusionOperation = cutlass::epilogue::fusion::LinCombEltActBlockScaleFactor<
|
||||
cutlass::epilogue::thread::SiLu,
|
||||
OutputSFVectorSize,
|
||||
ElementD,
|
||||
ElementAccumulator,
|
||||
ElementSF,
|
||||
LayoutC,
|
||||
ElementC>;
|
||||
|
||||
// Core kernel configurations
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using EpilogueOperatorClass = cutlass::arch::OpClassTensorOp; // Epilogue Operator class tag
|
||||
using MainloopOperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Mainloop Operator class tag
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
|
||||
// Runtime Cluster Shape
|
||||
using ClusterShape = Shape<int32_t,int32_t,_1>;
|
||||
/* // For Static Cluster Shape:
|
||||
use ClusterShape = Shape<_2,_1,_1> for example
|
||||
using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); // for 2SM config
|
||||
using OutputTileShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); // for epilogue builder
|
||||
using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); // for mainloop builder
|
||||
*/
|
||||
|
||||
// Different configs for 1SM and 2SM MMA kernel
|
||||
struct MMA1SMConfig {
|
||||
using MmaTileShape = Shape<_128,_256,_256>;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100; // Kernel to launch
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch
|
||||
using OutputTileShape = decltype(shape_div(MmaTileShape{}, Shape<_1,_1,_1>{}));
|
||||
};
|
||||
|
||||
struct MMA2SMConfig {
|
||||
using MmaTileShape = Shape<_256,_256,_256>;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmNvf4Sm100; // Kernel to launch
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch
|
||||
using OutputTileShape = decltype(shape_div(MmaTileShape{}, Shape<_2,_1,_1>{}));
|
||||
};
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, EpilogueOperatorClass,
|
||||
typename MMA1SMConfig::OutputTileShape, ClusterShape,
|
||||
Shape<_128,_64>,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
ElementD, LayoutC *, AlignmentD,
|
||||
typename MMA1SMConfig::EpilogueSchedule
|
||||
// , FusionOperation // Enable for SF Output
|
||||
>::CollectiveOp;
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, MainloopOperatorClass,
|
||||
ElementA, LayoutA *, AlignmentA,
|
||||
ElementB, LayoutB *, AlignmentB,
|
||||
ElementAccumulator,
|
||||
typename MMA1SMConfig::MmaTileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
typename MMA1SMConfig::KernelSchedule
|
||||
>::CollectiveOp;
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
using Gemm = Gemm1SM;
|
||||
|
||||
using CollectiveEpilogue2SM = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, EpilogueOperatorClass,
|
||||
typename MMA2SMConfig::OutputTileShape, ClusterShape,
|
||||
Shape<_128,_64>,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
ElementD, LayoutC *, AlignmentD,
|
||||
typename MMA2SMConfig::EpilogueSchedule
|
||||
// , FusionOperation // Enable for SF Output
|
||||
>::CollectiveOp;
|
||||
using CollectiveMainloop2SM = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, MainloopOperatorClass,
|
||||
ElementA, LayoutA *, AlignmentA,
|
||||
ElementB, LayoutB *, AlignmentB,
|
||||
ElementAccumulator,
|
||||
typename MMA2SMConfig::MmaTileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
typename MMA2SMConfig::KernelSchedule
|
||||
>::CollectiveOp;
|
||||
using GemmKernel2SM = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop2SM,
|
||||
CollectiveEpilogue2SM
|
||||
>;
|
||||
using Gemm2SM = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel2SM>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||
|
||||
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
|
||||
using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
|
||||
using Sm100BlockScaledOutputConfig = cutlass::detail::Sm100BlockScaledOutputConfig<
|
||||
OutputSFVectorSize,
|
||||
cute::is_same_v<typename FusionOperation::GmemLayoutTagScalefactor,
|
||||
cutlass::layout::RowMajor> ? cute::UMMA::Major::K : cute::UMMA::Major::MN
|
||||
>;
|
||||
using OutputSFAtom = typename Sm100BlockScaledOutputConfig::SfAtom;
|
||||
using LayoutSFD = typename Sm100BlockScaledOutputConfig::LayoutSF;
|
||||
|
||||
// Host-side allocations
|
||||
std::vector<StrideA> stride_A_host;
|
||||
std::vector<StrideB> stride_B_host;
|
||||
std::vector<LayoutSFA> layout_SFA_host;
|
||||
std::vector<LayoutSFA> layout_SFB_host;
|
||||
std::vector<StrideC> stride_C_host;
|
||||
std::vector<StrideD> stride_D_host;
|
||||
|
||||
std::vector<ElementAccumulator> alpha_host;
|
||||
std::vector<ElementAccumulator> beta_host;
|
||||
|
||||
using HostTensorA = cutlass::HostTensor<typename Gemm::ElementA, cutlass::layout::PackedVectorLayout>;
|
||||
using HostTensorB = cutlass::HostTensor<typename Gemm::ElementB, cutlass::layout::PackedVectorLayout>;
|
||||
using HostTensorSF = cutlass::HostTensor<typename Gemm::GemmKernel::ElementSF, cutlass::layout::PackedVectorLayout>;
|
||||
using HostTensorC = cutlass::HostTensor<typename Gemm::ElementC, cutlass::layout::PackedVectorLayout>;
|
||||
using HostTensorD = cutlass::HostTensor<typename Gemm::EpilogueOutputOp::ElementOutput, cutlass::layout::PackedVectorLayout>;
|
||||
std::vector<HostTensorA> block_A;
|
||||
std::vector<HostTensorB> block_B;
|
||||
std::vector<HostTensorSF> block_SFA;
|
||||
std::vector<HostTensorSF> block_SFB;
|
||||
std::vector<HostTensorC> block_C;
|
||||
std::vector<HostTensorD> block_D;
|
||||
std::vector<HostTensorSF> block_SFD;
|
||||
std::vector<HostTensorD> block_ref_D;
|
||||
|
||||
// Device-side allocations
|
||||
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
|
||||
|
||||
cutlass::DeviceAllocation<const typename Gemm::ElementA *> ptr_A;
|
||||
cutlass::DeviceAllocation<const typename Gemm::ElementB *> ptr_B;
|
||||
cutlass::DeviceAllocation<const typename Gemm::GemmKernel::ElementSF *> ptr_SFA;
|
||||
cutlass::DeviceAllocation<const typename Gemm::GemmKernel::ElementSF *> ptr_SFB;
|
||||
cutlass::DeviceAllocation<const typename Gemm::ElementC *> ptr_C;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_D;
|
||||
cutlass::DeviceAllocation<typename Gemm::GemmKernel::ElementSF *> ptr_SFD;
|
||||
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_ref_D;
|
||||
|
||||
cutlass::DeviceAllocation<StrideA> stride_A;
|
||||
cutlass::DeviceAllocation<StrideB> stride_B;
|
||||
cutlass::DeviceAllocation<LayoutSFA> layout_SFA;
|
||||
cutlass::DeviceAllocation<LayoutSFB> layout_SFB;
|
||||
cutlass::DeviceAllocation<StrideC> stride_C;
|
||||
cutlass::DeviceAllocation<StrideD> stride_D;
|
||||
|
||||
// Note, this is an array of pointers to alpha and beta scaling values per group
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
|
||||
// A matrix wide constant value to scale the output matrix
|
||||
// Avoids generating small FP4 values.
|
||||
// NormConst is a single device-side constant value, its not per-batch or per-group
|
||||
cutlass::DeviceAllocation<ElementAccumulator> norm_constant_device;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
template <typename T>
|
||||
auto make_iterator(T* ptr) {
|
||||
using namespace cute;
|
||||
if constexpr (cute::is_subbyte_v<T>) {
|
||||
return subbyte_iterator<T>(ptr);
|
||||
}
|
||||
else {
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool verification = true;
|
||||
|
||||
float alpha = FLT_MAX;
|
||||
float beta = FLT_MAX;
|
||||
float norm_constant = 1.0;
|
||||
int iterations = 10;
|
||||
int m = 1024, n = 2048, k = 512, groups = 10;
|
||||
dim3 cluster_shape = dim3(2,1,1);
|
||||
dim3 cluster_shape_fallback = dim3(2,1,1);
|
||||
RasterOrderOptions raster_order = RasterOrderOptions::AlongN;
|
||||
int max_sm_count = INT_MAX;
|
||||
std::string benchmark_path;
|
||||
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host;
|
||||
int const tma_alignment_bits = 128;
|
||||
int const alignment = tma_alignment_bits / cutlass::sizeof_bits<ElementInput>::value;
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
if (cmd.check_cmd_line_flag("no-verif")) {
|
||||
verification = false;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", m);
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("groups", groups);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, FLT_MAX);
|
||||
cmd.get_cmd_line_argument("beta", beta, FLT_MAX);
|
||||
cmd.get_cmd_line_argument("norm_constant", norm_constant, float(1.0));
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("benchmark", benchmark_path);
|
||||
cmd.get_cmd_line_argument("cluster_m", cluster_shape.x);
|
||||
cmd.get_cmd_line_argument("cluster_n", cluster_shape.y);
|
||||
cmd.get_cmd_line_argument("cluster_fallback_m", cluster_shape_fallback.x);
|
||||
cmd.get_cmd_line_argument("cluster_fallback_n", cluster_shape_fallback.y);
|
||||
cmd.get_cmd_line_argument("max_sm_count", max_sm_count, INT_MAX);
|
||||
|
||||
// Decide how to initialize the problems
|
||||
if (!benchmark_path.empty()) {
|
||||
if (!benchmark_problems()) {
|
||||
problem_sizes_host.clear();
|
||||
return;
|
||||
}
|
||||
}
|
||||
else {
|
||||
randomize_problems(cmd);
|
||||
}
|
||||
|
||||
char raster_char;
|
||||
cmd.get_cmd_line_argument("raster", raster_char);
|
||||
|
||||
if (raster_char == 'N' || raster_char == 'n') {
|
||||
raster_order = RasterOrderOptions::AlongN;
|
||||
}
|
||||
else if (raster_char == 'M' || raster_char == 'm') {
|
||||
raster_order = RasterOrderOptions::AlongM;
|
||||
}
|
||||
}
|
||||
|
||||
void randomize_problems(cutlass::CommandLine &cmd) {
|
||||
int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1;
|
||||
cmd.get_cmd_line_argument("m", cmd_line_m);
|
||||
cmd.get_cmd_line_argument("n", cmd_line_n);
|
||||
cmd.get_cmd_line_argument("k", cmd_line_k);
|
||||
|
||||
problem_sizes_host.reserve(groups);
|
||||
|
||||
for (int i = groups; i > 0; i--) {
|
||||
int m = cmd_line_m;
|
||||
int n = cmd_line_n;
|
||||
int k = cmd_line_k;
|
||||
if (m < 1) {
|
||||
m = alignment * ((rand() % 64) + 1);
|
||||
}
|
||||
if (n < 1) {
|
||||
n = alignment * ((rand() % 64) + 1);
|
||||
}
|
||||
if (k < 1) {
|
||||
k = alignment * ((rand() % 64) + 1);
|
||||
}
|
||||
problem_sizes_host.push_back({m, n, k});
|
||||
}
|
||||
}
|
||||
|
||||
/// Load a benchmark
|
||||
bool benchmark_problems() {
|
||||
std::ifstream file(benchmark_path);
|
||||
if (!file.good()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
while (file.good()) {
|
||||
|
||||
int idx = -1;
|
||||
std::string extent_str;
|
||||
|
||||
file >> idx >> extent_str;
|
||||
|
||||
if (idx < 0 || extent_str.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
cutlass::gemm::GemmCoord extent;
|
||||
std::vector<std::string> tokens;
|
||||
|
||||
cutlass::CommandLine::tokenize(tokens, extent_str, 'x');
|
||||
|
||||
for (int i = 0; i < int(tokens.size()); ++i) {
|
||||
int x = std::atoi(tokens.at(i).c_str());
|
||||
|
||||
// round up
|
||||
if (x % alignment) {
|
||||
x += (alignment - (x % alignment));
|
||||
}
|
||||
|
||||
extent.at(i) = x;
|
||||
}
|
||||
|
||||
if (extent.product()) {
|
||||
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
|
||||
}
|
||||
}
|
||||
groups = static_cast<int>(problem_sizes_host.size());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "75_blackwell_grouped_gemm_block_scaled\n\n"
|
||||
<< " Blackwell Block Scaled Narrow Precision Grouped GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --m=<int> Sets the M extent of the GEMM for all groups\n"
|
||||
<< " --n=<int> Sets the N extent of the GEMM for all groups\n"
|
||||
<< " --k=<int> Sets the K extent of the GEMM for all groups\n"
|
||||
<< " --groups=<int> Sets the number of individual GEMM problems for Grouped GEMM\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n"
|
||||
<< " --norm_constant=<f32> Epilogue scalar normalization constant for the output matrix\n\n"
|
||||
<< " --cluster_m=<int> and --cluster_n=<int> Sets the X,Y dims of the preferred cluster shape\n"
|
||||
<< " --cluster_fallback_m=<int> and --cluster_fallback_n=<int> Sets the X,Y dims of the fallback cluster shape\n\n"
|
||||
<< " --raster=<char> CTA Rasterization direction (N for along N, M for along M)\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform\n\n"
|
||||
<< " --benchmark=<str> Executes a benchmark problem size\n"
|
||||
<< " --max_sm_count=<int> Run kernels using only these number of SMs\n"
|
||||
<< " --no-verif Do not run (host-side) verification kernels\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "75_blackwell_grouped_gemm_block_scaled" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s, std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host) const
|
||||
{
|
||||
// Number of real-valued multiply-adds
|
||||
uint64_t fmas = uint64_t();
|
||||
|
||||
for (auto const & problem : problem_sizes_host) {
|
||||
fmas += static_cast<uint64_t>(get<0>(problem)) *
|
||||
static_cast<uint64_t>(get<1>(problem)) *
|
||||
static_cast<uint64_t>(get<2>(problem));
|
||||
}
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * uint64_t(fmas);
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms = 0.0;
|
||||
double gflops = 0.0;
|
||||
cutlass::Status status = cutlass::Status::kSuccess;
|
||||
cudaError_t error = cudaSuccess;
|
||||
bool passed = false;
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_block(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
uint64_t seed) {
|
||||
|
||||
double scope_max, scope_min;
|
||||
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if constexpr (bits_input == 1) {
|
||||
scope_max = 2;
|
||||
scope_min = 0;
|
||||
}
|
||||
else if constexpr (bits_input <= 6) {
|
||||
scope_max = 2;
|
||||
scope_min = -2;
|
||||
}
|
||||
else if constexpr (bits_input <= 8) {
|
||||
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>) {
|
||||
scope_max = 4;
|
||||
scope_min = 1;
|
||||
}
|
||||
else {
|
||||
scope_max = 1;
|
||||
scope_min = -1;
|
||||
}
|
||||
}
|
||||
else{
|
||||
scope_max = 4;
|
||||
scope_min = -4;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Allocates device-side data
|
||||
void allocate(const Options &options) {
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
auto M = get<0>(problem);
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
|
||||
auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1});
|
||||
auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1});
|
||||
auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1});
|
||||
auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1});
|
||||
|
||||
auto layout_A = make_layout(make_shape(M, K, 1), stride_A);
|
||||
auto layout_B = make_layout(make_shape(N, K, 1), stride_B);
|
||||
auto layout_C = make_layout(make_shape(M, N, 1), stride_C);
|
||||
auto layout_D = make_layout(make_shape(M, N, 1), stride_D);
|
||||
auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFD = Sm100BlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1));
|
||||
|
||||
stride_A_host.push_back(stride_A);
|
||||
stride_B_host.push_back(stride_B);
|
||||
layout_SFA_host.push_back(layout_SFA);
|
||||
layout_SFB_host.push_back(layout_SFB);
|
||||
stride_C_host.push_back(stride_C);
|
||||
stride_D_host.push_back(stride_D);
|
||||
|
||||
block_A.push_back(HostTensorA(cutlass::make_Coord(size(layout_A))));
|
||||
block_B.push_back(HostTensorB(cutlass::make_Coord(size(layout_B))));
|
||||
block_SFA.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFA)))));
|
||||
block_SFB.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFB)))));
|
||||
block_C.push_back(HostTensorC(cutlass::make_Coord(size(layout_C))));
|
||||
block_D.push_back(HostTensorD(cutlass::make_Coord(size(layout_D))));
|
||||
block_SFD.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFD)))));
|
||||
block_ref_D.push_back(HostTensorD(cutlass::make_Coord(size(layout_D))));
|
||||
}
|
||||
block_alpha.reset(options.groups);
|
||||
block_beta.reset(options.groups);
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
void initialize(const Options &options) {
|
||||
uint64_t seed = 2020;
|
||||
problem_sizes.reset(options.groups);
|
||||
problem_sizes.copy_from_host(options.problem_sizes_host.data());
|
||||
|
||||
//
|
||||
// Assign pointers
|
||||
//
|
||||
|
||||
std::vector<typename Gemm::ElementA *> ptr_A_host(options.groups);
|
||||
std::vector<typename Gemm::ElementB *> ptr_B_host(options.groups);
|
||||
std::vector<typename Gemm::GemmKernel::ElementSF *> ptr_SFA_host(options.groups);
|
||||
std::vector<typename Gemm::GemmKernel::ElementSF *> ptr_SFB_host(options.groups);
|
||||
std::vector<typename Gemm::ElementC *> ptr_C_host(options.groups);
|
||||
std::vector<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_D_host(options.groups);
|
||||
std::vector<typename Gemm::GemmKernel::ElementSF *> ptr_SFD_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
|
||||
initialize_block(block_A.at(i).host_view(), seed + 2021);
|
||||
initialize_block(block_B.at(i).host_view(), seed + 2022);
|
||||
initialize_block(block_C.at(i).host_view(), seed + 2023);
|
||||
initialize_block(block_SFA.at(i).host_view(), seed + 2024);
|
||||
initialize_block(block_SFB.at(i).host_view(), seed + 2025);
|
||||
|
||||
block_A.at(i).sync_device();
|
||||
block_B.at(i).sync_device();
|
||||
block_C.at(i).sync_device();
|
||||
block_SFA.at(i).sync_device();
|
||||
block_SFB.at(i).sync_device();
|
||||
|
||||
ptr_A_host.at(i) = block_A.at(i).device_data();
|
||||
ptr_B_host.at(i) = block_B.at(i).device_data();
|
||||
ptr_SFA_host.at(i) = block_SFA.at(i).device_data();
|
||||
ptr_SFB_host.at(i) = block_SFB.at(i).device_data();
|
||||
ptr_C_host.at(i) = block_C.at(i).device_data();
|
||||
ptr_D_host.at(i) = block_D.at(i).device_data();
|
||||
ptr_SFD_host.at(i) = block_SFD.at(i).device_data();
|
||||
|
||||
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
|
||||
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
|
||||
ptr_alpha_host.at(i) = block_alpha.get() + i;
|
||||
ptr_beta_host.at(i) = block_beta.get() + i;
|
||||
}
|
||||
|
||||
ptr_A.reset(options.groups);
|
||||
ptr_A.copy_from_host(ptr_A_host.data());
|
||||
|
||||
ptr_B.reset(options.groups);
|
||||
ptr_B.copy_from_host(ptr_B_host.data());
|
||||
|
||||
ptr_SFA.reset(options.groups);
|
||||
ptr_SFA.copy_from_host(ptr_SFA_host.data());
|
||||
|
||||
ptr_SFB.reset(options.groups);
|
||||
ptr_SFB.copy_from_host(ptr_SFB_host.data());
|
||||
|
||||
ptr_C.reset(options.groups);
|
||||
ptr_C.copy_from_host(ptr_C_host.data());
|
||||
|
||||
ptr_D.reset(options.groups);
|
||||
ptr_D.copy_from_host(ptr_D_host.data());
|
||||
|
||||
ptr_SFD.reset(options.groups);
|
||||
ptr_SFD.copy_from_host(ptr_SFD_host.data());
|
||||
|
||||
stride_A.reset(options.groups);
|
||||
stride_A.copy_from_host(stride_A_host.data());
|
||||
|
||||
stride_B.reset(options.groups);
|
||||
stride_B.copy_from_host(stride_B_host.data());
|
||||
|
||||
layout_SFA.reset(options.groups);
|
||||
layout_SFA.copy_from_host(layout_SFA_host.data());
|
||||
|
||||
layout_SFB.reset(options.groups);
|
||||
layout_SFB.copy_from_host(layout_SFB_host.data());
|
||||
|
||||
stride_C.reset(options.groups);
|
||||
stride_C.copy_from_host(stride_C_host.data());
|
||||
|
||||
stride_D.reset(options.groups);
|
||||
stride_D.copy_from_host(stride_D_host.data());
|
||||
|
||||
alpha_device.reset(options.groups);
|
||||
alpha_device.copy_from_host(ptr_alpha_host.data());
|
||||
beta_device.reset(options.groups);
|
||||
beta_device.copy_from_host(ptr_beta_host.data());
|
||||
|
||||
block_alpha.copy_from_host(alpha_host.data());
|
||||
block_beta.copy_from_host(beta_host.data());
|
||||
|
||||
norm_constant_device.reset(1);
|
||||
norm_constant_device.copy_from_host(&options.norm_constant);
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
template <typename Gemm>
|
||||
typename Gemm::Arguments args_from_options(Options &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = min(cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id), options.max_sm_count);
|
||||
|
||||
if (!is_static_v<ClusterShape>) {
|
||||
if (size<0>(typename Gemm::GemmKernel::CollectiveMainloop::AtomThrShapeMNK{}) == 2 &&
|
||||
(options.cluster_shape.x < 2 || options.cluster_shape_fallback.x < 2)) {
|
||||
std::cout << "Error: MMA2SMConfig kernel config needs cluster_dim.x >= 2" << std::endl;
|
||||
}
|
||||
hw_info.cluster_shape = options.cluster_shape;
|
||||
hw_info.cluster_shape_fallback = options.cluster_shape_fallback;
|
||||
}
|
||||
|
||||
typename Gemm::Arguments arguments;
|
||||
decltype(arguments.epilogue.thread) fusion_args;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
|
||||
// If alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
|
||||
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
|
||||
if (options.alpha != FLT_MAX){
|
||||
// Single alpha for all groups
|
||||
fusion_args.alpha = options.alpha;
|
||||
fusion_args.alpha_ptr_array = nullptr;
|
||||
fusion_args.dAlpha = {_0{}, _0{}, 0};
|
||||
}
|
||||
else {
|
||||
fusion_args.alpha = 0;
|
||||
fusion_args.alpha_ptr_array = alpha_device.get();
|
||||
// Only one alpha per each group
|
||||
fusion_args.dAlpha = {_0{}, _0{}, 1};
|
||||
}
|
||||
if (options.beta != FLT_MAX) {
|
||||
// Single beta for all groups
|
||||
fusion_args.beta = options.beta;
|
||||
fusion_args.beta_ptr_array = nullptr;
|
||||
fusion_args.dBeta = {_0{}, _0{}, 0};
|
||||
}
|
||||
else {
|
||||
fusion_args.beta = 0;
|
||||
fusion_args.beta_ptr_array = beta_device.get();
|
||||
// Only one beta per each group
|
||||
fusion_args.dBeta = {_0{}, _0{}, 1};
|
||||
}
|
||||
// Output Block SF
|
||||
// fusion_args.block_scale_factor_ptr = ptr_SFD.get(); // Enable for SF Output
|
||||
// fusion_args.norm_constant_ptr = norm_constant_device.get(); // Enable for SF Output
|
||||
|
||||
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
|
||||
scheduler.raster_order = options.raster_order;
|
||||
|
||||
if (host_problem_shapes_available) {
|
||||
arguments = typename Gemm::Arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
|
||||
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(),
|
||||
ptr_SFA.get(), layout_SFA.get(), ptr_SFB.get(), layout_SFB.get()},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info, scheduler
|
||||
};
|
||||
}
|
||||
else {
|
||||
arguments = typename Gemm::Arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), nullptr},
|
||||
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(),
|
||||
ptr_SFA.get(), layout_SFA.get(), ptr_SFB.get(), layout_SFB.get()},
|
||||
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info, scheduler
|
||||
};
|
||||
}
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
using namespace cute;
|
||||
bool passed = true;
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
auto problem = options.problem_sizes_host.at(i);
|
||||
auto M = get<0>(problem);
|
||||
auto N = get<1>(problem);
|
||||
auto K = get<2>(problem);
|
||||
|
||||
auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1});
|
||||
auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1});
|
||||
auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1});
|
||||
auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1});
|
||||
auto layout_A = make_layout(make_shape(M, K, 1), stride_A);
|
||||
auto layout_B = make_layout(make_shape(N, K, 1), stride_B);
|
||||
auto layout_C = make_layout(make_shape(M, N, 1), stride_C);
|
||||
auto layout_D = make_layout(make_shape(M, N, 1), stride_D);
|
||||
auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
|
||||
auto layout_SFD = Sm100BlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1));
|
||||
|
||||
// Create the arguments for host reference implementation
|
||||
Tensor tensor_A = make_tensor(make_iterator(block_A.at(i).host_data()), layout_A);
|
||||
Tensor tensor_SFA = make_tensor(block_SFA.at(i).host_data(), layout_SFA);
|
||||
Tensor tensor_B = make_tensor(make_iterator(block_B.at(i).host_data()), layout_B);
|
||||
Tensor tensor_SFB = make_tensor(block_SFB.at(i).host_data(), layout_SFB);
|
||||
cutlass::reference::host::GettBlockScalingMainloopParams<ElementAccumulator,
|
||||
decltype(tensor_A),
|
||||
decltype(tensor_SFA),
|
||||
decltype(tensor_B),
|
||||
decltype(tensor_SFB)
|
||||
>
|
||||
mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
|
||||
|
||||
auto tensor_C = cute::make_tensor(make_iterator(block_C.at(i).host_data()), layout_C);
|
||||
auto tensor_ref_D = cute::make_tensor(make_iterator(block_ref_D.at(i).host_data()), layout_D);
|
||||
|
||||
cutlass::reference::host::GettEpilogueParams<
|
||||
float, float,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
decltype(tensor_C), decltype(tensor_ref_D)
|
||||
> epilogue_params{};
|
||||
|
||||
epilogue_params.C = tensor_C;
|
||||
epilogue_params.D = tensor_ref_D;
|
||||
epilogue_params.alpha = alpha_host.at(i);
|
||||
epilogue_params.beta = beta_host.at(i);
|
||||
|
||||
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
|
||||
|
||||
block_D.at(i).sync_host();
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
passed &= cutlass::reference::host::TensorEquals(block_ref_D.at(i).host_view(), block_D.at(i).host_view());
|
||||
}
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
std::cout << " " << options.problem_sizes_host.at(i);
|
||||
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
|
||||
}
|
||||
std::cout << " Groups : " << options.groups << std::endl;
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options<Gemm>(options, host_problem_shapes_available);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
if (options.verification) {
|
||||
std::cout << " Host-side verification is now running - may be very slow for large cases." << std::endl;
|
||||
result.passed = verify(options);
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
else {
|
||||
std::cout << " Verfication is turned off for this run." << std::endl;
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(gemm.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average setup and runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host);
|
||||
|
||||
std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS : " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
|
||||
if (__CUDACC_VER_MAJOR__ < 12 ||
|
||||
((__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)
|
||||
)
|
||||
) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer.\n";
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (!(props.major == 10 && props.minor == 0)) {
|
||||
std::cerr
|
||||
<< "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 100a).\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
std::cout << "Running kernel with 1SM MMA config:" << std::endl;
|
||||
run<Gemm1SM>(options, false /*host_problem_shapes_available*/);
|
||||
std::cout << "Running kernel with 2SM MMA config:" << std::endl;
|
||||
run<Gemm2SM>(options, false /*host_problem_shapes_available*/);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
88
examples/75_blackwell_grouped_gemm/CMakeLists.txt
Normal file
88
examples/75_blackwell_grouped_gemm/CMakeLists.txt
Normal file
@ -0,0 +1,88 @@
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# Note that we set --iterations=0 for all tests below to disable the performance benchmarking.
|
||||
# Only the correctness check will be run by these commands.
|
||||
|
||||
|
||||
|
||||
set(TEST_RANDOM --iterations=0) # Random problem sizes
|
||||
set(TEST_RANDOM_LARGE_GROUP --groups=50 --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes
|
||||
set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=50 --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Random problem sizes
|
||||
set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=1) # Random problem sizes
|
||||
|
||||
set(TEST_FIXED --m=2048 --n=5120 --k=8192 --iterations=0) # Fixed problem sizes
|
||||
set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=51 --iterations=0) # Fixed problem sizes
|
||||
|
||||
set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes
|
||||
set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=50 --iterations=0) # Small problem sizes
|
||||
|
||||
set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes
|
||||
set(TEST_RANDOM_PERF_LARGE_GROUP --groups=50 --iterations=10) # Random problem sizes
|
||||
|
||||
if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
|
||||
cutlass_example_add_executable(
|
||||
75_blackwell_grouped_gemm
|
||||
75_blackwell_grouped_gemm.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_RANDOM
|
||||
TEST_RANDOM_LARGE_GROUP
|
||||
TEST_EPILOGUE
|
||||
TEST_EPILOGUE_LARGE_GROUP
|
||||
TEST_EPILOGUE_OP
|
||||
TEST_EPILOGUE_OP_LARGE_GROUP
|
||||
TEST_FIXED
|
||||
TEST_FIXED_LARGE_GROUP
|
||||
TEST_SMALL
|
||||
TEST_SMALL_LARGE_GROUP
|
||||
TEST_RANDOM_PERF
|
||||
TEST_RANDOM_PERF_LARGE_GROUP
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
75_blackwell_grouped_gemm_block_scaled
|
||||
75_blackwell_grouped_gemm_block_scaled.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_RANDOM
|
||||
TEST_RANDOM_LARGE_GROUP
|
||||
TEST_EPILOGUE
|
||||
TEST_EPILOGUE_LARGE_GROUP
|
||||
TEST_EPILOGUE_OP
|
||||
TEST_EPILOGUE_OP_LARGE_GROUP
|
||||
TEST_FIXED
|
||||
TEST_FIXED_LARGE_GROUP
|
||||
TEST_SMALL
|
||||
TEST_SMALL_LARGE_GROUP
|
||||
TEST_RANDOM_PERF
|
||||
TEST_RANDOM_PERF_LARGE_GROUP
|
||||
)
|
||||
endif()
|
||||
534
examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu
Normal file
534
examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu
Normal file
@ -0,0 +1,534 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Simple dgrad convolution example targeting NVIDIA Blackwell SM100 Tensor Core MMA using CUTLASS 3.x APIs.
|
||||
|
||||
This example demonstrate a simple way to instantiate and run a dgrad convolution kernel using the new CUTLASS 3.0
|
||||
APIs on NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
The basic computation logic of dgrad convolution kernel is, take 3D convolution as an example:
|
||||
Xformed Actication (NZPQK) * Weight/Filter (KTRSC) = Activation (NDHWC)
|
||||
|
||||
where in terms of GEMM perspective,
|
||||
Matrix A = Xformed Activation, Matrix B = Weight/Filter, Matrix C = Activation
|
||||
|
||||
This example instantiates a simple dgrad kernel using TMA + UMMA + Warp Specialized design with input and output types are fp16.
|
||||
Alpha/beta scaling is supported while fusions like relu/bias/per-channel scaling are not supported in this example.
|
||||
|
||||
Usage:
|
||||
|
||||
$ ./examples/76_blackwell_conv/76_blackwell_conv_dgrad --n=4 --d=1 --h=8 --w=8 --c=64 --k=64 --t=1 --r=3 --s=3 --pad_d=0
|
||||
--pad_h=1 --pad_w=1 --stride_d=1 --stride_h=1 --stride_w=1 --dilation_d=1 --dilation_h=1 --dilation_w=1
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/convnd_problem_shape.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/conv/dispatch_policy.hpp"
|
||||
#include "cutlass/conv/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/conv/device/conv_universal_adapter.hpp"
|
||||
#include "cutlass/conv/kernel/conv_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/convolution.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Conv kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Activation matrix configuration
|
||||
using ElementAct = half_t; // Element type for activation matrix
|
||||
constexpr int AlignmentAct = 128 / cutlass::sizeof_bits<ElementAct>::value; // Memory access granularity/alignment of activation matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Weight/Filter matrix configuration
|
||||
using ElementFlt = half_t; // Element type for weight/filter matrix operand
|
||||
constexpr int AlignmentFlt = 128 / cutlass::sizeof_bits<ElementFlt>::value; // Memory access granularity/alignment of weight/filter matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Xformed activation matrix configuration
|
||||
using ElementXformedAct = half_t; // Element type for xformed activation matrix operand
|
||||
constexpr int AlignmentXformedAct = 128 / cutlass::sizeof_bits<ElementXformedAct>::value; // Memory access granularity/alignment of xformed activation matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Layout of matrix A/B/C in gemm's perspecitive.
|
||||
using LayoutA = cutlass::layout::TensorNDHWC;
|
||||
using LayoutB = cutlass::layout::TensorNDHWC;
|
||||
using LayoutC = cutlass::layout::TensorNDHWC;
|
||||
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for internal computation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
constexpr cutlass::conv::Operator ConvOp = cutlass::conv::Operator::kDgrad; // Convolution operation
|
||||
|
||||
// Kernel Perf config
|
||||
using TileShape = Shape<_128,_128,Shape<_64>>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
// Build the epilogue
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementAct, LayoutC, AlignmentAct,
|
||||
ElementAct, LayoutC, AlignmentAct,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
// Build the mainloop
|
||||
using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, ConvOp,
|
||||
ElementXformedAct, LayoutA, AlignmentXformedAct,
|
||||
ElementFlt, LayoutB, AlignmentFlt,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::conv::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::conv::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
// Compose into a kernel
|
||||
using ProblemShape=cutlass::conv::ConvProblemShape<ConvOp, CollectiveMainloop::DispatchPolicy::NumSpatialDimensions>;
|
||||
using ConvKernel = cutlass::conv::kernel::ConvUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Conv = cutlass::conv::device::ConvUniversalAdapter<ConvKernel>;
|
||||
|
||||
using StrideC = typename Conv::ConvKernel::StrideC;
|
||||
using StrideD = typename Conv::ConvKernel::StrideD;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::DeviceAllocation<ElementXformedAct> block_A;
|
||||
cutlass::DeviceAllocation<ElementFlt> block_B;
|
||||
cutlass::DeviceAllocation<ElementAct> block_C;
|
||||
cutlass::DeviceAllocation<ElementAct> block_D;
|
||||
cutlass::DeviceAllocation<ElementAct> block_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int n, d, h, w, c, k, t, r, s, z, p, q;
|
||||
int pad_d, pad_h, pad_w;
|
||||
int stride_d, stride_h, stride_w;
|
||||
int dilation_d, dilation_h, dilation_w;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
n(4), d(1), h(8), w(8), c(64), k(64), t(1), r(3), s(3),
|
||||
pad_d(0), pad_h(1), pad_w(1),
|
||||
stride_d(1), stride_h(1), stride_w(1),
|
||||
dilation_d(1), dilation_h(1), dilation_w(1),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("d", d);
|
||||
cmd.get_cmd_line_argument("h", h);
|
||||
cmd.get_cmd_line_argument("w", w);
|
||||
cmd.get_cmd_line_argument("c", c);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("t", t);
|
||||
cmd.get_cmd_line_argument("r", r);
|
||||
cmd.get_cmd_line_argument("s", s);
|
||||
cmd.get_cmd_line_argument("pad_d", pad_d);
|
||||
cmd.get_cmd_line_argument("pad_h", pad_h);
|
||||
cmd.get_cmd_line_argument("pad_w", pad_w);
|
||||
cmd.get_cmd_line_argument("stride_d", stride_d);
|
||||
cmd.get_cmd_line_argument("stride_h", stride_h);
|
||||
cmd.get_cmd_line_argument("stride_w", stride_w);
|
||||
cmd.get_cmd_line_argument("dilation_d", dilation_d);
|
||||
cmd.get_cmd_line_argument("dilation_h", dilation_h);
|
||||
cmd.get_cmd_line_argument("dilation_w", dilation_w);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
|
||||
// Calculate z,p,q based on inputs.
|
||||
z = 1 + (d + 2 * pad_d - ((t - 1) * dilation_d + 1)) / stride_d;
|
||||
p = 1 + (h + 2 * pad_h - ((r - 1) * dilation_h + 1)) / stride_h;
|
||||
q = 1 + (w + 2 * pad_w - ((s - 1) * dilation_w + 1)) / stride_w;
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "76_blackwell_conv_dgrad\n\n"
|
||||
<< " Blackwell FP16 dgrad convolution using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --n=<int> Sets the batch size of the Activation\n"
|
||||
<< " --d=<int> Sets the depth size of the Activation\n"
|
||||
<< " --h=<int> Sets the height of the Activation\n"
|
||||
<< " --w=<int> Sets the width of the Activation\n"
|
||||
<< " --c=<int> Sets the channel size of the Activation\n"
|
||||
<< " --k=<int> Sets the image numbers of the Filter\n"
|
||||
<< " --t=<int> Sets the depth size of the Filter\n"
|
||||
<< " --r=<int> Sets the height of the Filter\n"
|
||||
<< " --s=<int> Sets the width of the Filter\n"
|
||||
<< " --pad_d=<int> Sets the padding size in depth\n"
|
||||
<< " --pad_h=<int> Sets the padding size in height\n"
|
||||
<< " --pad_w=<int> Sets the padding size in width\n"
|
||||
<< " --stride_d=<int> Sets the traversal stride size in depth\n"
|
||||
<< " --stride_h=<int> Sets the traversal stride size in height\n"
|
||||
<< " --stride_w=<int> Sets the traversal stride size in width\n"
|
||||
<< " --dialtion_d=<int> Sets the filter dilation size in depth\n"
|
||||
<< " --dialtion_h=<int> Sets the filter dilation size in height\n"
|
||||
<< " --dialtion_w=<int> Sets the filter dilation size in width\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "76_blackwell_conv_dgrad" << " --n=4 --d=1 --h=8 --w=8 --c=64 --k=64 --t=1 --r=3 --s=3 --pad_d=0"
|
||||
<< " --pad_h=1 --pad_w=1 --stride_d=1 --stride_h=1 --stride_w=1 --dilation_d=1 --dilation_h=1 --dilation_w=1 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * (n * d * h * w) * c * (t * r * s * k);
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Conv setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023) {
|
||||
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(0);
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(-2);
|
||||
} else {
|
||||
scope_max = Element(8);
|
||||
scope_min = Element(-8);
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the Conv and reference Conv
|
||||
void initialize(const Options &options) {
|
||||
|
||||
// Construct ConvProblemShape
|
||||
ProblemShape problem_shape(
|
||||
cutlass::conv::Mode::kCrossCorrelation,
|
||||
{options.n, options.d, options.h, options.w, options.c}, // ndhwc
|
||||
{options.k, options.t, options.r, options.s, options.c}, // ktrsc
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding lower (pad_d, pad_h, pad_w)
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding upper (pad_d, pad_h, pad_w)
|
||||
{options.stride_d, options.stride_h, options.stride_w}, // stride (stride_d, stride_h, stride_w)
|
||||
{options.dilation_d, options.dilation_h, options.dilation_w}, // dilation (dilation_d, dilation_h, dilation_w)
|
||||
1 // group
|
||||
);
|
||||
|
||||
// Setup stride_C/D
|
||||
cute::for_each(cute::make_seq<cute::rank<0>(StrideC{})>{}, [&](auto i) {
|
||||
cute::get<0, i>(stride_C) = problem_shape.stride_C[ProblemShape::RankT-2-i];
|
||||
});
|
||||
cute::for_each(cute::make_seq<cute::rank<0>(StrideD{})>{}, [&](auto i) {
|
||||
cute::get<0, i>(stride_D) = problem_shape.stride_C[ProblemShape::RankT-2-i];
|
||||
});
|
||||
|
||||
block_A.reset(problem_shape.size_A());
|
||||
block_B.reset(problem_shape.size_B());
|
||||
block_C.reset(problem_shape.size_C());
|
||||
block_D.reset(problem_shape.size_C());
|
||||
block_ref_D.reset(problem_shape.size_C());
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Conv::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
// Construct ConvProblemShape
|
||||
ProblemShape problem_shape(
|
||||
cutlass::conv::Mode::kCrossCorrelation,
|
||||
{options.n, options.d, options.h, options.w, options.c}, // ndhwc
|
||||
{options.k, options.t, options.r, options.s, options.c}, // ktrsc
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding lower (pad_d, pad_h, pad_w)
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding upper (pad_d, pad_h, pad_w)
|
||||
{options.stride_d, options.stride_h, options.stride_w}, // stride (stride_d, stride_h, stride_w)
|
||||
{options.dilation_d, options.dilation_h, options.dilation_w}, // dilation (dilation_d, dilation_h, dilation_w)
|
||||
1 // group
|
||||
);
|
||||
|
||||
typename Conv::Arguments arguments{
|
||||
problem_shape,
|
||||
{block_A.get(), block_B.get()},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({options.n, options.z, options.p, options.q, options.k}));
|
||||
cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({options.k, options.t, options.r, options.s, options.c}));
|
||||
cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({options.n, options.d, options.h, options.w, options.c}));
|
||||
cutlass::TensorRef ref_D(block_ref_D.get(), LayoutC::packed({options.n, options.d, options.h, options.w, options.c}));
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Construct Conv3dProblemSize with user defined inputs.
|
||||
cutlass::conv::Conv3dProblemSize problem_size(
|
||||
cutlass::Tensor5DCoord(options.n, options.d, options.h, options.w, options.c), // ndhwc
|
||||
cutlass::Tensor5DCoord(options.k, options.t, options.r, options.s, options.c), // ktrsc
|
||||
cutlass::make_Coord(options.pad_d, options.pad_h, options.pad_w), // padding
|
||||
cutlass::make_Coord(options.stride_d, options.stride_h, options.stride_w), // stride (stride_d, stride_h, stride_w)
|
||||
cutlass::make_Coord(options.dilation_d, options.dilation_h, options.dilation_w), // dilation (dilation_d, dilation_h, dilation_w)
|
||||
cutlass::Tensor5DCoord(options.n, options.z, options.p, options.q, options.k) // nzpqk
|
||||
);
|
||||
|
||||
// Launch device reference conv kernel
|
||||
cutlass::reference::device::Conv3dDgrad(problem_size, ref_A, ref_B, ref_C, ref_D, options.alpha, options.beta);
|
||||
|
||||
// Wait for kernel to finish
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Conv conv;
|
||||
|
||||
// Create a structure of conv kernel arguments suitable for invoking an instance of Conv
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Conv::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(conv.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(conv.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(conv.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(conv.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(conv.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size:" << std::endl;
|
||||
std::cout << " Activation(n,d,h,w,c) = (" << options.n << ',' << options.d << ',' << options.h << ',' << options.w << ',' << options.c << "), ";
|
||||
std::cout << " Filter(k,t,r,s,c) = (" << options.k << ',' << options.t << ',' << options.r << ',' << options.s << ',' << options.c << "), ";
|
||||
std::cout << " Xformed Activation(n,z,p,q,k) = (" << options.n << ',' << options.z << ',' << options.p << ',' << options.q << ',' << options.k << ")" << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run<Conv>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
534
examples/76_blackwell_conv/76_blackwell_conv_fprop.cu
Normal file
534
examples/76_blackwell_conv/76_blackwell_conv_fprop.cu
Normal file
@ -0,0 +1,534 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Simple fprop convolution example targeting NVIDIA Blackwell SM100 Tensor Core MMA using CUTLASS 3.x APIs.
|
||||
|
||||
This example demonstrate a simple way to instantiate and run a fprop convolution kernel using the new CUTLASS 3.0
|
||||
APIs on NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
The basic computation logic of fprop convolution kernel is, take 3D convolution as an example:
|
||||
Activation (NDHWC) * Weight/Filter (KTRSC) = Xformed Actication (NZPQK)
|
||||
|
||||
where in terms of GEMM perspective,
|
||||
Matrix A = Activation, Matrix B = Weight/Filter, Matrix C = Xformed Activation
|
||||
|
||||
This example instantiates a simple fprop kernel using TMA + UMMA + Warp Specialized design with input and output types are fp16.
|
||||
Alpha/beta scaling is supported while fusions like relu/bias/per-channel scaling are not supported in this example.
|
||||
|
||||
Usage:
|
||||
|
||||
$ ./examples/76_blackwell_conv/76_blackwell_conv_fprop --n=4 --d=1 --h=8 --w=8 --c=64 --k=64 --t=1 --r=3 --s=3 --pad_d=0
|
||||
--pad_h=1 --pad_w=1 --stride_d=1 --stride_h=1 --stride_w=1 --dilation_d=1 --dilation_h=1 --dilation_w=1
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/convnd_problem_shape.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/conv/dispatch_policy.hpp"
|
||||
#include "cutlass/conv/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/conv/device/conv_universal_adapter.hpp"
|
||||
#include "cutlass/conv/kernel/conv_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/convolution.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Conv kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Activation matrix configuration
|
||||
using ElementAct = half_t; // Element type for activation matrix
|
||||
constexpr int AlignmentAct = 128 / cutlass::sizeof_bits<ElementAct>::value; // Memory access granularity/alignment of activation matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Weight/Filter matrix configuration
|
||||
using ElementFlt = half_t; // Element type for weight/filter matrix operand
|
||||
constexpr int AlignmentFlt = 128 / cutlass::sizeof_bits<ElementFlt>::value; // Memory access granularity/alignment of weight/filter matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Xformed activation matrix configuration
|
||||
using ElementXformedAct = half_t; // Element type for xformed activation matrix operand
|
||||
constexpr int AlignmentXformedAct = 128 / cutlass::sizeof_bits<ElementXformedAct>::value; // Memory access granularity/alignment of xformed activation matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Layout of matrix A/B/C in gemm's perspecitive.
|
||||
using LayoutA = cutlass::layout::TensorNDHWC;
|
||||
using LayoutB = cutlass::layout::TensorNDHWC;
|
||||
using LayoutC = cutlass::layout::TensorNDHWC;
|
||||
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for internal computation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
constexpr cutlass::conv::Operator ConvOp = cutlass::conv::Operator::kFprop; // Convolution operation
|
||||
|
||||
// Kernel Perf config
|
||||
using TileShape = Shape<_128,_128,Shape<_64>>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
// Build the epilogue
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementXformedAct, LayoutC, AlignmentXformedAct,
|
||||
ElementXformedAct, LayoutC, AlignmentXformedAct,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
// Build the mainloop
|
||||
using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, ConvOp,
|
||||
ElementAct, LayoutA, AlignmentAct,
|
||||
ElementFlt, LayoutB, AlignmentFlt,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::conv::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::conv::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
// Compose into a kernel
|
||||
using ProblemShape=cutlass::conv::ConvProblemShape<ConvOp, CollectiveMainloop::DispatchPolicy::NumSpatialDimensions>;
|
||||
using ConvKernel = cutlass::conv::kernel::ConvUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Conv = cutlass::conv::device::ConvUniversalAdapter<ConvKernel>;
|
||||
|
||||
using StrideC = typename Conv::ConvKernel::StrideC;
|
||||
using StrideD = typename Conv::ConvKernel::StrideD;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::DeviceAllocation<ElementAct> block_A;
|
||||
cutlass::DeviceAllocation<ElementFlt> block_B;
|
||||
cutlass::DeviceAllocation<ElementXformedAct> block_C;
|
||||
cutlass::DeviceAllocation<ElementXformedAct> block_D;
|
||||
cutlass::DeviceAllocation<ElementXformedAct> block_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int n, d, h, w, c, k, t, r, s, z, p, q;
|
||||
int pad_d, pad_h, pad_w;
|
||||
int stride_d, stride_h, stride_w;
|
||||
int dilation_d, dilation_h, dilation_w;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
n(4), d(1), h(8), w(8), c(64), k(64), t(1), r(3), s(3),
|
||||
pad_d(0), pad_h(1), pad_w(1),
|
||||
stride_d(1), stride_h(1), stride_w(1),
|
||||
dilation_d(1), dilation_h(1), dilation_w(1),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("d", d);
|
||||
cmd.get_cmd_line_argument("h", h);
|
||||
cmd.get_cmd_line_argument("w", w);
|
||||
cmd.get_cmd_line_argument("c", c);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("t", t);
|
||||
cmd.get_cmd_line_argument("r", r);
|
||||
cmd.get_cmd_line_argument("s", s);
|
||||
cmd.get_cmd_line_argument("pad_d", pad_d);
|
||||
cmd.get_cmd_line_argument("pad_h", pad_h);
|
||||
cmd.get_cmd_line_argument("pad_w", pad_w);
|
||||
cmd.get_cmd_line_argument("stride_d", stride_d);
|
||||
cmd.get_cmd_line_argument("stride_h", stride_h);
|
||||
cmd.get_cmd_line_argument("stride_w", stride_w);
|
||||
cmd.get_cmd_line_argument("dilation_d", dilation_d);
|
||||
cmd.get_cmd_line_argument("dilation_h", dilation_h);
|
||||
cmd.get_cmd_line_argument("dilation_w", dilation_w);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
|
||||
// Calculate z,p,q based on inputs.
|
||||
z = 1 + (d + 2 * pad_d - ((t - 1) * dilation_d + 1)) / stride_d;
|
||||
p = 1 + (h + 2 * pad_h - ((r - 1) * dilation_h + 1)) / stride_h;
|
||||
q = 1 + (w + 2 * pad_w - ((s - 1) * dilation_w + 1)) / stride_w;
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "76_blackwell_conv_fprop\n\n"
|
||||
<< " Blackwell FP16 fprop convolution using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --n=<int> Sets the batch size of the Activation\n"
|
||||
<< " --d=<int> Sets the depth size of the Activation\n"
|
||||
<< " --h=<int> Sets the height of the Activation\n"
|
||||
<< " --w=<int> Sets the width of the Activation\n"
|
||||
<< " --c=<int> Sets the channel size of the Activation\n"
|
||||
<< " --k=<int> Sets the image numbers of the Filter\n"
|
||||
<< " --t=<int> Sets the depth size of the Filter\n"
|
||||
<< " --r=<int> Sets the height of the Filter\n"
|
||||
<< " --s=<int> Sets the width of the Filter\n"
|
||||
<< " --pad_d=<int> Sets the padding size in depth\n"
|
||||
<< " --pad_h=<int> Sets the padding size in height\n"
|
||||
<< " --pad_w=<int> Sets the padding size in width\n"
|
||||
<< " --stride_d=<int> Sets the traversal stride size in depth\n"
|
||||
<< " --stride_h=<int> Sets the traversal stride size in height\n"
|
||||
<< " --stride_w=<int> Sets the traversal stride size in width\n"
|
||||
<< " --dialtion_d=<int> Sets the filter dilation size in depth\n"
|
||||
<< " --dialtion_h=<int> Sets the filter dilation size in height\n"
|
||||
<< " --dialtion_w=<int> Sets the filter dilation size in width\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "76_blackwell_conv_fprop" << " --n=4 --d=1 --h=8 --w=8 --c=64 --k=64 --t=1 --r=3 --s=3 --pad_d=0"
|
||||
<< " --pad_h=1 --pad_w=1 --stride_d=1 --stride_h=1 --stride_w=1 --dilation_d=1 --dilation_h=1 --dilation_w=1 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * (n * z * p * q) * k * (t * r * s * c);
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Conv setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023) {
|
||||
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(0);
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(-2);
|
||||
} else {
|
||||
scope_max = Element(8);
|
||||
scope_min = Element(-8);
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the Conv and reference Conv
|
||||
void initialize(const Options &options) {
|
||||
|
||||
// Construct ConvProblemShape
|
||||
ProblemShape problem_shape(
|
||||
cutlass::conv::Mode::kCrossCorrelation,
|
||||
{options.n, options.d, options.h, options.w, options.c}, // ndhwc
|
||||
{options.k, options.t, options.r, options.s, options.c}, // ktrsc
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding lower (pad_d, pad_h, pad_w)
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding upper (pad_d, pad_h, pad_w)
|
||||
{options.stride_d, options.stride_h, options.stride_w}, // stride (stride_d, stride_h, stride_w)
|
||||
{options.dilation_d, options.dilation_h, options.dilation_w}, // dilation (dilation_d, dilation_h, dilation_w)
|
||||
1 // group
|
||||
);
|
||||
|
||||
// Setup stride_C/D
|
||||
cute::for_each(cute::make_seq<cute::rank<0>(StrideC{})>{}, [&](auto i) {
|
||||
cute::get<0, i>(stride_C) = problem_shape.stride_C[ProblemShape::RankT-2-i];
|
||||
});
|
||||
cute::for_each(cute::make_seq<cute::rank<0>(StrideD{})>{}, [&](auto i) {
|
||||
cute::get<0, i>(stride_D) = problem_shape.stride_C[ProblemShape::RankT-2-i];
|
||||
});
|
||||
|
||||
block_A.reset(problem_shape.size_A());
|
||||
block_B.reset(problem_shape.size_B());
|
||||
block_C.reset(problem_shape.size_C());
|
||||
block_D.reset(problem_shape.size_C());
|
||||
block_ref_D.reset(problem_shape.size_C());
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Conv::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
// Construct ConvProblemShape
|
||||
ProblemShape problem_shape(
|
||||
cutlass::conv::Mode::kCrossCorrelation,
|
||||
{options.n, options.d, options.h, options.w, options.c}, // ndhwc
|
||||
{options.k, options.t, options.r, options.s, options.c}, // ktrsc
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding lower (pad_d, pad_h, pad_w)
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding upper (pad_d, pad_h, pad_w)
|
||||
{options.stride_d, options.stride_h, options.stride_w}, // stride (stride_d, stride_h, stride_w)
|
||||
{options.dilation_d, options.dilation_h, options.dilation_w}, // dilation (dilation_d, dilation_h, dilation_w)
|
||||
1 // group
|
||||
);
|
||||
|
||||
typename Conv::Arguments arguments{
|
||||
problem_shape,
|
||||
{block_A.get(), block_B.get()},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({options.n, options.d, options.h, options.w, options.c}));
|
||||
cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({options.k, options.t, options.r, options.s, options.c}));
|
||||
cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({options.n, options.z, options.p, options.q, options.k}));
|
||||
cutlass::TensorRef ref_D(block_ref_D.get(), LayoutC::packed({options.n, options.z, options.p, options.q, options.k}));
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Construct Conv3dProblemSize with user defined inputs.
|
||||
cutlass::conv::Conv3dProblemSize problem_size(
|
||||
cutlass::Tensor5DCoord(options.n, options.d, options.h, options.w, options.c), // ndhwc
|
||||
cutlass::Tensor5DCoord(options.k, options.t, options.r, options.s, options.c), // ktrsc
|
||||
cutlass::make_Coord(options.pad_d, options.pad_h, options.pad_w), // padding
|
||||
cutlass::make_Coord(options.stride_d, options.stride_h, options.stride_w), // stride (stride_d, stride_h, stride_w)
|
||||
cutlass::make_Coord(options.dilation_d, options.dilation_h, options.dilation_w), // dilation (dilation_d, dilation_h, dilation_w)
|
||||
cutlass::Tensor5DCoord(options.n, options.z, options.p, options.q, options.k) // nzpqk
|
||||
);
|
||||
|
||||
// Launch device reference conv kernel
|
||||
cutlass::reference::device::Conv3dFprop(problem_size, ref_A, ref_B, ref_C, ref_D, options.alpha, options.beta);
|
||||
|
||||
// Wait for kernel to finish
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Conv conv;
|
||||
|
||||
// Create a structure of conv kernel arguments suitable for invoking an instance of Conv
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Conv::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(conv.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(conv.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(conv.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(conv.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(conv.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size:" << std::endl;
|
||||
std::cout << " Activation(n,d,h,w,c) = (" << options.n << ',' << options.d << ',' << options.h << ',' << options.w << ',' << options.c << "), ";
|
||||
std::cout << " Filter(k,t,r,s,c) = (" << options.k << ',' << options.t << ',' << options.r << ',' << options.s << ',' << options.c << "), ";
|
||||
std::cout << " Xformed Activation(n,z,p,q,k) = (" << options.n << ',' << options.z << ',' << options.p << ',' << options.q << ',' << options.k << ")" << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run<Conv>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
530
examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu
Normal file
530
examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu
Normal file
@ -0,0 +1,530 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Simple wgrad convolution example targeting NVIDIA Blackwell SM100 Tensor Core MMA using CUTLASS 3.x APIs.
|
||||
|
||||
This example demonstrate a simple way to instantiate and run a wgrad convolution kernel using the new CUTLASS 3.0
|
||||
APIs on NVIDIA Blackwell SM100 architecture.
|
||||
|
||||
The basic computation logic of wgrad convolution kernel is, take 3D convolution as an example:
|
||||
Xformed Actication (NZPQK) * Activation (NDHWC) = Weight/Filter (KTRSC)
|
||||
|
||||
where in terms of GEMM perspective,
|
||||
Matrix A = Xformed Activation, Matrix B = Activation, Matrix C = Weight/Filter
|
||||
|
||||
This example instantiates a simple wgrad kernel using TMA + UMMA + Warp Specialized design with input and output types are fp16.
|
||||
Alpha/beta scaling is supported while fusions like relu/bias/per-channel scaling are not supported in this example.
|
||||
|
||||
Usage:
|
||||
|
||||
$ ./examples/76_blackwell_conv/76_blackwell_conv_wgrad --n=4 --d=1 --h=8 --w=8 --c=64 --k=64 --t=1 --r=3 --s=3 --pad_d=0
|
||||
--pad_h=1 --pad_w=1 --stride_d=1 --stride_h=1 --stride_w=1 --dilation_d=1 --dilation_h=1 --dilation_w=1
|
||||
*/
|
||||
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/convnd_problem_shape.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/conv/dispatch_policy.hpp"
|
||||
#include "cutlass/conv/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/conv/device/conv_universal_adapter.hpp"
|
||||
#include "cutlass/conv/kernel/conv_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/convolution.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Conv kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Activation matrix configuration
|
||||
using ElementAct = half_t; // Element type for activation matrix
|
||||
constexpr int AlignmentAct = 128 / cutlass::sizeof_bits<ElementAct>::value; // Memory access granularity/alignment of activation matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Weight/Filter matrix configuration
|
||||
using ElementFlt = half_t; // Element type for weight/filter matrix operand
|
||||
constexpr int AlignmentFlt = 128 / cutlass::sizeof_bits<ElementFlt>::value; // Memory access granularity/alignment of weight/filter matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Xformed activation matrix configuration
|
||||
using ElementXformedAct = half_t; // Element type for xformed activation matrix operand
|
||||
constexpr int AlignmentXformedAct = 128 / cutlass::sizeof_bits<ElementXformedAct>::value; // Memory access granularity/alignment of xformed activation matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Layout of matrix A/B/C in gemm's perspecitive.
|
||||
using LayoutA = cutlass::layout::TensorNDHWC;
|
||||
using LayoutB = cutlass::layout::TensorNDHWC;
|
||||
using LayoutC = cutlass::layout::TensorKCSRT;
|
||||
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for internal computation
|
||||
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
constexpr cutlass::conv::Operator ConvOp = cutlass::conv::Operator::kWgrad; // Convolution operation
|
||||
|
||||
// Kernel Perf config
|
||||
using TileShape = Shape<_128,Shape<_128>,Shape<_64>>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
|
||||
|
||||
// Build the epilogue
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute,
|
||||
ElementFlt, LayoutC, AlignmentFlt,
|
||||
ElementFlt, LayoutC, AlignmentFlt,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
// Build the mainloop
|
||||
using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, ConvOp,
|
||||
ElementXformedAct, LayoutA, AlignmentXformedAct,
|
||||
ElementAct, LayoutB, AlignmentAct,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::conv::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::conv::collective::KernelScheduleAuto
|
||||
>::CollectiveOp;
|
||||
|
||||
// Compose into a kernel
|
||||
using ProblemShape=cutlass::conv::ConvProblemShape<ConvOp, CollectiveMainloop::DispatchPolicy::NumSpatialDimensions>;
|
||||
using ConvKernel = cutlass::conv::kernel::ConvUniversal<
|
||||
ProblemShape,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
>;
|
||||
|
||||
using Conv = cutlass::conv::device::ConvUniversalAdapter<ConvKernel>;
|
||||
|
||||
using StrideC = typename Conv::ConvKernel::StrideC;
|
||||
using StrideD = typename Conv::ConvKernel::StrideD;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideC stride_C;
|
||||
StrideD stride_D;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::DeviceAllocation<ElementXformedAct> block_A;
|
||||
cutlass::DeviceAllocation<ElementAct> block_B;
|
||||
cutlass::DeviceAllocation<ElementFlt> block_C;
|
||||
cutlass::DeviceAllocation<ElementFlt> block_D;
|
||||
cutlass::DeviceAllocation<ElementFlt> block_ref_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
float alpha, beta;
|
||||
int iterations;
|
||||
int n, d, h, w, c, k, t, r, s, z, p, q;
|
||||
int pad_d, pad_h, pad_w;
|
||||
int stride_d, stride_h, stride_w;
|
||||
int dilation_d, dilation_h, dilation_w;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
n(4), d(1), h(8), w(8), c(64), k(64), t(1), r(3), s(3),
|
||||
pad_d(0), pad_h(1), pad_w(1),
|
||||
stride_d(1), stride_h(1), stride_w(1),
|
||||
dilation_d(1), dilation_h(1), dilation_w(1),
|
||||
alpha(1.f), beta(0.f),
|
||||
iterations(10)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("d", d);
|
||||
cmd.get_cmd_line_argument("h", h);
|
||||
cmd.get_cmd_line_argument("w", w);
|
||||
cmd.get_cmd_line_argument("c", c);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("t", t);
|
||||
cmd.get_cmd_line_argument("r", r);
|
||||
cmd.get_cmd_line_argument("s", s);
|
||||
cmd.get_cmd_line_argument("pad_d", pad_d);
|
||||
cmd.get_cmd_line_argument("pad_h", pad_h);
|
||||
cmd.get_cmd_line_argument("pad_w", pad_w);
|
||||
cmd.get_cmd_line_argument("stride_d", stride_d);
|
||||
cmd.get_cmd_line_argument("stride_h", stride_h);
|
||||
cmd.get_cmd_line_argument("stride_w", stride_w);
|
||||
cmd.get_cmd_line_argument("dilation_d", dilation_d);
|
||||
cmd.get_cmd_line_argument("dilation_h", dilation_h);
|
||||
cmd.get_cmd_line_argument("dilation_w", dilation_w);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
|
||||
// Calculate z,p,q based on inputs.
|
||||
z = 1 + (d + 2 * pad_d - ((t - 1) * dilation_d + 1)) / stride_d;
|
||||
p = 1 + (h + 2 * pad_h - ((r - 1) * dilation_h + 1)) / stride_h;
|
||||
q = 1 + (w + 2 * pad_w - ((s - 1) * dilation_w + 1)) / stride_w;
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "76_blackwell_conv_wgrad\n\n"
|
||||
<< " Blackwell FP16 wgrad convolution using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --n=<int> Sets the batch size of the Activation\n"
|
||||
<< " --d=<int> Sets the depth size of the Activation\n"
|
||||
<< " --h=<int> Sets the height of the Activation\n"
|
||||
<< " --w=<int> Sets the width of the Activation\n"
|
||||
<< " --c=<int> Sets the channel size of the Activation\n"
|
||||
<< " --k=<int> Sets the image numbers of the Filter\n"
|
||||
<< " --t=<int> Sets the depth size of the Filter\n"
|
||||
<< " --r=<int> Sets the height of the Filter\n"
|
||||
<< " --s=<int> Sets the width of the Filter\n"
|
||||
<< " --pad_d=<int> Sets the padding size in depth\n"
|
||||
<< " --pad_h=<int> Sets the padding size in height\n"
|
||||
<< " --pad_w=<int> Sets the padding size in width\n"
|
||||
<< " --stride_d=<int> Sets the traversal stride size in depth\n"
|
||||
<< " --stride_h=<int> Sets the traversal stride size in height\n"
|
||||
<< " --stride_w=<int> Sets the traversal stride size in width\n"
|
||||
<< " --dialtion_d=<int> Sets the filter dilation size in depth\n"
|
||||
<< " --dialtion_h=<int> Sets the filter dilation size in height\n"
|
||||
<< " --dialtion_w=<int> Sets the filter dilation size in width\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "76_blackwell_conv_wgrad" << " --n=4 --d=1 --h=8 --w=8 --c=64 --k=64 --t=1 --r=3 --s=3 --pad_d=0"
|
||||
<< " --pad_h=1 --pad_w=1 --stride_d=1 --stride_h=1 --stride_w=1 --dilation_d=1 --dilation_h=1 --dilation_w=1 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * k * (t * r * s * c) * (n * z * p * q);
|
||||
double gflop = double(flop) / double(1.0e9);
|
||||
return gflop / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Conv setup and evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
bool initialize_block(
|
||||
cutlass::DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023) {
|
||||
|
||||
Element scope_max, scope_min;
|
||||
int bits_input = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits_input == 1) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(0);
|
||||
} else if (bits_input <= 8) {
|
||||
scope_max = Element(2);
|
||||
scope_min = Element(-2);
|
||||
} else {
|
||||
scope_max = Element(8);
|
||||
scope_min = Element(-8);
|
||||
}
|
||||
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, scope_max, scope_min, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the Conv and reference Conv
|
||||
void initialize(const Options &options) {
|
||||
|
||||
// Construct ConvProblemShape
|
||||
ProblemShape problem_shape(
|
||||
cutlass::conv::Mode::kCrossCorrelation,
|
||||
{options.n, options.d, options.h, options.w, options.c}, // ndhwc
|
||||
{options.k, options.t, options.r, options.s, options.c}, // ktrsc
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding lower (pad_d, pad_h, pad_w)
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding upper (pad_d, pad_h, pad_w)
|
||||
{options.stride_d, options.stride_h, options.stride_w}, // stride (stride_d, stride_h, stride_w)
|
||||
{options.dilation_d, options.dilation_h, options.dilation_w}, // dilation (dilation_d, dilation_h, dilation_w)
|
||||
1 // group
|
||||
);
|
||||
|
||||
// Setup stride_C/D
|
||||
stride_C = cutlass::make_cute_packed_stride(StrideC{}, problem_shape.shape_C, problem_shape.stride_C, ConvOp);
|
||||
stride_D = cutlass::make_cute_packed_stride(StrideD{}, problem_shape.shape_C, problem_shape.stride_C, ConvOp);
|
||||
|
||||
block_A.reset(problem_shape.size_A());
|
||||
block_B.reset(problem_shape.size_B());
|
||||
block_C.reset(problem_shape.size_C());
|
||||
block_D.reset(problem_shape.size_C());
|
||||
block_ref_D.reset(problem_shape.size_C());
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Conv::Arguments args_from_options(const Options &options)
|
||||
{
|
||||
// Construct ConvProblemShape
|
||||
ProblemShape problem_shape(
|
||||
cutlass::conv::Mode::kCrossCorrelation,
|
||||
{options.n, options.d, options.h, options.w, options.c}, // ndhwc
|
||||
{options.k, options.t, options.r, options.s, options.c}, // ktrsc
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding lower (pad_d, pad_h, pad_w)
|
||||
{options.pad_d, options.pad_h, options.pad_w}, // padding upper (pad_d, pad_h, pad_w)
|
||||
{options.stride_d, options.stride_h, options.stride_w}, // stride (stride_d, stride_h, stride_w)
|
||||
{options.dilation_d, options.dilation_h, options.dilation_w}, // dilation (dilation_d, dilation_h, dilation_w)
|
||||
1 // group
|
||||
);
|
||||
|
||||
typename Conv::Arguments arguments{
|
||||
problem_shape,
|
||||
{block_A.get(), block_B.get()},
|
||||
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
|
||||
};
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
bool verify(const Options &options) {
|
||||
cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({options.n, options.z, options.p, options.q, options.k}));
|
||||
cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({options.n, options.d, options.h, options.w, options.c}));
|
||||
cutlass::TensorRef ref_C(block_C.get(), LayoutA::packed({options.k, options.t, options.r, options.s, options.c}));
|
||||
cutlass::TensorRef ref_D(block_ref_D.get(), LayoutA::packed({options.k, options.t, options.r, options.s, options.c}));
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Construct Conv3dProblemSize with user defined inputs.
|
||||
cutlass::conv::Conv3dProblemSize problem_size(
|
||||
cutlass::Tensor5DCoord(options.n, options.d, options.h, options.w, options.c), // ndhwc
|
||||
cutlass::Tensor5DCoord(options.k, options.t, options.r, options.s, options.c), // ktrsc
|
||||
cutlass::make_Coord(options.pad_d, options.pad_h, options.pad_w), // padding
|
||||
cutlass::make_Coord(options.stride_d, options.stride_h, options.stride_w), // stride (stride_d, stride_h, stride_w)
|
||||
cutlass::make_Coord(options.dilation_d, options.dilation_h, options.dilation_w), // dilation (dilation_d, dilation_h, dilation_w)
|
||||
cutlass::Tensor5DCoord(options.n, options.z, options.p, options.q, options.k) // nzpqk
|
||||
);
|
||||
|
||||
// Launch device reference conv kernel
|
||||
cutlass::reference::device::Conv3dWgrad(problem_size, ref_A, ref_B, ref_C, ref_D, options.alpha, options.beta);
|
||||
|
||||
// Wait for kernel to finish
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
{
|
||||
initialize(options);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Conv conv;
|
||||
|
||||
// Create a structure of conv kernel arguments suitable for invoking an instance of Conv
|
||||
auto arguments = args_from_options(options);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Conv::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check if the problem size is supported or not
|
||||
CUTLASS_CHECK(conv.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(conv.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(conv.run());
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = verify(options);
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(conv.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(conv.run());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Problem Size:" << std::endl;
|
||||
std::cout << " Activation(n,d,h,w,c) = (" << options.n << ',' << options.d << ',' << options.h << ',' << options.w << ',' << options.c << "), ";
|
||||
std::cout << " Filter(k,t,r,s,c) = (" << options.k << ',' << options.t << ',' << options.r << ',' << options.s << ',' << options.c << "), ";
|
||||
std::cout << " Xformed Activation(n,z,p,q,k) = (" << options.n << ',' << options.z << ',' << options.p << ',' << options.q << ',' << options.k << ")" << std::endl;
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
|
||||
// and must have compute capability at least 90.
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
|
||||
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(¤t_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
|
||||
std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
run<Conv>(options);
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
46
examples/76_blackwell_conv/CMakeLists.txt
Normal file
46
examples/76_blackwell_conv/CMakeLists.txt
Normal file
@ -0,0 +1,46 @@
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
|
||||
if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
|
||||
cutlass_example_add_executable(
|
||||
76_blackwell_conv_fprop
|
||||
76_blackwell_conv_fprop.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
76_blackwell_conv_dgrad
|
||||
76_blackwell_conv_dgrad.cu
|
||||
)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
76_blackwell_conv_wgrad
|
||||
76_blackwell_conv_wgrad.cu
|
||||
)
|
||||
endif()
|
||||
990
examples/77_blackwell_fmha/77_blackwell_fmha.cu
Normal file
990
examples/77_blackwell_fmha/77_blackwell_fmha.cu
Normal file
@ -0,0 +1,990 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Example implementation of fused multi-head attention for the NVIDIA Blackwell SM100
|
||||
architecture using CUTLASS 3.
|
||||
|
||||
MQA/GQA
|
||||
-------
|
||||
|
||||
The head dimension can be represented as a tuple, where the K/V strides in the
|
||||
first dimension is zero. This has the effect of MQA or GQA.
|
||||
* MHA is (head_size:head_stride).
|
||||
* MQA is (head_size:head_stride) in Q and (head_size:_0) in K and V.
|
||||
* GQA is (grouped_heads,heads_kv):(head_stride,grouped_heads*head_stride) in Q
|
||||
and (grouped_heads,heads_kv):(0,head_stride) in K and V
|
||||
|
||||
Output Scale
|
||||
------------
|
||||
|
||||
The output scale gets passed to the collective mainloop, and is applied
|
||||
using FP32 compute pre-quantization
|
||||
|
||||
Variable Sequence Length
|
||||
------------------------
|
||||
|
||||
For variable sequence length, pass in VariableLength objects
|
||||
(max_seqlen, cumulative_seqlen_ptr) in the problem shape for
|
||||
seqlen Q and KV.
|
||||
|
||||
Support
|
||||
---------
|
||||
|
||||
Right now e4m3 with fp32 compute is using a 256x256 tiling and a head dimension
|
||||
of 128 is supported.
|
||||
|
||||
|
||||
Example usage:
|
||||
$ ./examples/77_blackell_fmha/77_blackell_fmha_fp8 \
|
||||
--b=2048 --h=2048 --d=2048 --q=2048 --k=2048
|
||||
*/
|
||||
|
||||
#define DSHOW(x) print(#x ": "); print(x); print("\n");
|
||||
#define DSHOWT(x) print(#x ": "); print_tensor(x); print("\n");
|
||||
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <regex>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "reference/fmha_fwd_reference.hpp"
|
||||
#include "reference/reference_abs_error.hpp"
|
||||
|
||||
#include "device/fmha.hpp"
|
||||
#include "collective/fmha_fusion.hpp"
|
||||
#include "collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp"
|
||||
#include "collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp"
|
||||
#include "kernel/fmha_options.hpp"
|
||||
#include "kernel/fmha_tile_scheduler.hpp"
|
||||
#include "kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
using namespace cutlass::fmha::collective;
|
||||
using namespace cutlass::fmha;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
enum class InitStyle {
|
||||
kOne, kLinearStride128, kLinearStride1, kRandom, kNone
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool error = false;
|
||||
|
||||
int b = 1;
|
||||
int h = 1;
|
||||
int h_k = 1;
|
||||
int q = 256;
|
||||
int k = 256;
|
||||
int d = 128;
|
||||
int iterations = 3;
|
||||
bool verify = false;
|
||||
bool verbose = false;
|
||||
|
||||
bool causal = false;
|
||||
bool residual = false;
|
||||
bool varlen = false;
|
||||
int sm_count = 0;
|
||||
|
||||
std::string kernel_filter;
|
||||
|
||||
InitStyle init_style_q = InitStyle::kRandom;
|
||||
InitStyle init_style_k = InitStyle::kRandom;
|
||||
InitStyle init_style_v = InitStyle::kRandom;
|
||||
|
||||
static void get_init_style_argument(cutlass::CommandLine& cmd, const char* name, InitStyle& dst, InitStyle const& src) {
|
||||
std::string s;
|
||||
cmd.get_cmd_line_argument(name, s, s);
|
||||
if (s.empty()) {
|
||||
dst = src;
|
||||
}
|
||||
else {
|
||||
if (s == "r") {
|
||||
dst = InitStyle::kRandom;
|
||||
}
|
||||
else if (s == "1") {
|
||||
dst = InitStyle::kOne;
|
||||
}
|
||||
else if (s == "d") {
|
||||
dst = InitStyle::kLinearStride1;
|
||||
}
|
||||
else if (s == "s") {
|
||||
dst = InitStyle::kLinearStride128;
|
||||
}
|
||||
else if (s == "n") {
|
||||
dst = InitStyle::kNone;
|
||||
}
|
||||
else {
|
||||
std::cout << "Error: " << s << " is not a valid input type.\n";
|
||||
std::exit(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
Options defaults;
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("d", d, defaults.d);
|
||||
cmd.get_cmd_line_argument("h", h, -1);
|
||||
if (h == -1) h = 2048 / d;
|
||||
|
||||
cmd.get_cmd_line_argument("h_k", h_k, -1);
|
||||
if (h_k == -1) h_k = h;
|
||||
|
||||
cmd.get_cmd_line_argument("q", q, -1);
|
||||
cmd.get_cmd_line_argument("k", k, -1);
|
||||
if (q == -1) q = k;
|
||||
if (k == -1) k = q;
|
||||
if (q == -1 && k == -1) q = k = defaults.q;
|
||||
|
||||
cmd.get_cmd_line_argument("b", b, -1);
|
||||
if (b == -1) b = 16384 / k;
|
||||
if (b == 0) b = 1;
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
|
||||
verify = cmd.check_cmd_line_flag("verify");
|
||||
verbose = cmd.check_cmd_line_flag("verbose");
|
||||
varlen = cmd.check_cmd_line_flag("varlen");
|
||||
std::string mask;
|
||||
cmd.get_cmd_line_argument<std::string>("mask", mask, "");
|
||||
if (mask == "no" || mask == "") {
|
||||
causal = residual = false;
|
||||
if (varlen) {
|
||||
residual = true;
|
||||
}
|
||||
}
|
||||
else if (mask == "causal") {
|
||||
residual = false;
|
||||
causal = true;
|
||||
}
|
||||
else if (mask == "residual") {
|
||||
residual = true;
|
||||
causal = false;
|
||||
}
|
||||
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
|
||||
|
||||
get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q);
|
||||
get_init_style_argument(cmd, "init-style", init_style_k, defaults.init_style_q);
|
||||
get_init_style_argument(cmd, "init-style", init_style_v, defaults.init_style_q);
|
||||
get_init_style_argument(cmd, "init-style-q", init_style_q, init_style_q);
|
||||
get_init_style_argument(cmd, "init-style-k", init_style_k, init_style_k);
|
||||
get_init_style_argument(cmd, "init-style-v", init_style_v, init_style_v);
|
||||
|
||||
cmd.get_cmd_line_argument("kernel-filter", kernel_filter, defaults.kernel_filter);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "77_blackwell_fmha\n\n"
|
||||
<< " This example showcases the use of CUTLASS's collective operation builders to easily construct\n"
|
||||
<< " fused multi-head attention forward-passkernels targeting NVIDIA's Blackwell architecture.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --b=<int> Sets the B extent\n"
|
||||
<< " --h=<int> Sets the H extent\n"
|
||||
<< " --h_k=<int> Sets the H_K/V extent (for GQA/MQA)\n"
|
||||
<< " --q=<int> Sets the Q extent\n"
|
||||
<< " --k=<int> Sets the K extent\n"
|
||||
<< " --d=<int> Sets the D extentn"
|
||||
<< " --iterations=<int> Benchmarking iterations\n"
|
||||
<< " --verify Verify results\n"
|
||||
<< " --verbose Print smem and execution time per kernel\n"
|
||||
<< " --mask=<no|residual|causal> Enables masking\n"
|
||||
<< " --varlen Enables variable sequence length\n"
|
||||
<< " B*Q and B*K become the total sequence length\n"
|
||||
<< " and are split B-ways, alternatingly +10% and -10%\n"
|
||||
<< " with the last batch sized to make it fit\n"
|
||||
<< " implies at least residual masking for correctness\n"
|
||||
<< " --sm-count Sets SM count rather than querying it\n"
|
||||
<< " --kernel-filter=<filter> Sets regexp to match kernel against\n"
|
||||
<< "\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
void initialize_block(
|
||||
DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023, InitStyle init_style = InitStyle::kRandom) {
|
||||
|
||||
switch (init_style) {
|
||||
case InitStyle::kOne: {
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, (Element) 1, (Element) 1);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kRandom: {
|
||||
cutlass::reference::device::BlockFillRandomGaussian(
|
||||
block.get(), block.size(), seed, (Element) 0, (Element) 1);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kLinearStride1: {
|
||||
std::vector<Element> data(block.size());
|
||||
for (size_t i = 0; i < block.size() / 128; i ++) {
|
||||
for (int j = 0; j < 128; j++) {
|
||||
data[j + 128*i] = static_cast<Element>((double) (j % 4));
|
||||
}
|
||||
}
|
||||
block.copy_from_host(data.data(), data.size());
|
||||
break;
|
||||
}
|
||||
case InitStyle::kLinearStride128: {
|
||||
std::vector<Element> data(block.size());
|
||||
for (size_t i = 0; i < block.size() / 128; i ++) {
|
||||
for (int j = 0; j < 128; j++) {
|
||||
data[j + 128*i] = static_cast<Element>((double) (i % 4));
|
||||
}
|
||||
}
|
||||
block.copy_from_host(data.data(), data.size());
|
||||
break;
|
||||
}
|
||||
case InitStyle::kNone: {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ExampleResult {
|
||||
bool passed = false;
|
||||
bool verified = false;
|
||||
float runtime_ms = 0;
|
||||
double tflops_tc_s = 0;
|
||||
double tops_exp2_s = 0;
|
||||
double tbytes_s = 0;
|
||||
size_t smem_size = 0;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
bool kIsVarlen,
|
||||
class TileShape,
|
||||
class DispatchPolicy,
|
||||
class ActiveMask,
|
||||
class... KernelOptions
|
||||
>
|
||||
struct FwdRunner {
|
||||
|
||||
#ifdef FP8
|
||||
using Element = cutlass::float_e4m3_t;
|
||||
#else
|
||||
using Element = cutlass::half_t;
|
||||
#endif
|
||||
|
||||
using ElementAccumulatorQK = float;
|
||||
using ElementAccumulatorPV = float;
|
||||
using ElementOut = cutlass::half_t;
|
||||
|
||||
// Q K D (B H)
|
||||
using ProblemShapeRegular = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
|
||||
using ProblemShapeVarlen = cute::tuple<VariableLength, VariableLength, int, cute::tuple<cute::tuple<int, int>, int>>;
|
||||
using ProblemShapeType = std::conditional_t<kIsVarlen, ProblemShapeVarlen, ProblemShapeRegular>;
|
||||
|
||||
using StrideQ = cute::tuple<int, _1, cute::tuple<cute::tuple<int, int>, int>>; // Q D (H_G H_R B)
|
||||
using StrideK = cute::tuple<int, _1, cute::tuple<cute::tuple<_0, int>, int>>; // K D (H_G H_R B)
|
||||
using StrideV = StrideK;
|
||||
using StrideO = StrideQ;
|
||||
using StrideLSE = cute::tuple<_1, cute::tuple<cute::tuple<int, int>, int>>; // Q (H_G H_R B)
|
||||
|
||||
static constexpr bool kIsPersistent = find_option_t<Tag::kIsPersistent, true_type, KernelOptions...>::value;
|
||||
using TileScheduler = std::conditional_t<kIsPersistent, cutlass::fmha::kernel::PersistentTileScheduler, cutlass::fmha::kernel::IndividualTileScheduler>;
|
||||
|
||||
using Mainloop =
|
||||
cutlass::fmha::collective::Sm100FmhaFwdMainloopTmaWarpspecialized<
|
||||
Element, ElementAccumulatorQK, ElementAccumulatorPV,
|
||||
TileShape, StrideQ, StrideK, StrideV,
|
||||
ActiveMask
|
||||
>;
|
||||
using Operation = cutlass::fmha::device::FMHA<
|
||||
cutlass::fmha::kernel::Sm100FmhaFwdKernelTmaWarpspecialized<
|
||||
ProblemShapeType,
|
||||
Mainloop,
|
||||
cutlass::fmha::collective::Sm100FmhaFwdEpilogueTmaWarpspecialized<
|
||||
ElementOut, ElementAccumulatorPV,
|
||||
typename Mainloop::TileShapePV,
|
||||
StrideO, StrideLSE
|
||||
>,
|
||||
TileScheduler
|
||||
>>;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Initialization
|
||||
StrideQ stride_Q;
|
||||
StrideK stride_K;
|
||||
StrideV stride_V;
|
||||
StrideO stride_O;
|
||||
StrideLSE stride_LSE;
|
||||
uint64_t seed = 0;
|
||||
|
||||
DeviceAllocation<Element> block_Q;
|
||||
DeviceAllocation<Element> block_K;
|
||||
DeviceAllocation<Element> block_V;
|
||||
DeviceAllocation<ElementOut> block_O;
|
||||
DeviceAllocation<ElementAccumulatorPV> block_LSE;
|
||||
DeviceAllocation<ElementOut> block_ref_O;
|
||||
DeviceAllocation<ElementAccumulatorPV> block_ref_LSE;
|
||||
|
||||
std::vector<int> cumulative_seqlen_q;
|
||||
std::vector<int> cumulative_seqlen_kv;
|
||||
DeviceAllocation<int> device_cumulative_seqlen_q;
|
||||
DeviceAllocation<int> device_cumulative_seqlen_kv;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
bool verify(const ProblemShapeType& problem_shape) {
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
stride_Q);
|
||||
|
||||
Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
stride_K);
|
||||
|
||||
Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
|
||||
select<1,2,3>(problem_shape),
|
||||
stride_V);
|
||||
|
||||
Tensor mO = make_tensor(make_gmem_ptr(block_ref_O.get()),
|
||||
select<0,2,3>(problem_shape),
|
||||
stride_O);
|
||||
|
||||
Tensor mLSE = make_tensor(make_gmem_ptr(block_ref_LSE.get()),
|
||||
select<0,3>(problem_shape),
|
||||
stride_LSE);
|
||||
|
||||
fmha_reference(problem_shape, mQ, mK, mV, mO, mLSE, ActiveMask{});
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Reference kernel failed. Last CUDA error: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
const double kMaxDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-2;
|
||||
const double kMeanDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-3;
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
double max_diff = 0;
|
||||
double mean_diff = 0;
|
||||
reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff);
|
||||
|
||||
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if (! passed_O) {
|
||||
std::cerr << "failed O: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
|
||||
// reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff);
|
||||
|
||||
bool passed_LSE = true; // future work
|
||||
// bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
// if ( ! passed_LSE) {
|
||||
// std::cerr << "failed LSE: max diff " << max_diff
|
||||
// << " mean " << mean_diff << std::endl;
|
||||
// }
|
||||
|
||||
return passed_O && passed_LSE;
|
||||
}
|
||||
|
||||
template<class ProblemShape>
|
||||
auto initialize_varlen(const ProblemShape& problem_size, const bool kVarlenSame = true) {
|
||||
int num_batches = get<3,1>(problem_size);
|
||||
|
||||
// generate Q as --b times
|
||||
// gaussian (--Q, --Q / 2) sampled positive
|
||||
// track cumulative
|
||||
std::mt19937 rng(0x202305151552ull);
|
||||
std::normal_distribution<double> dist_q(get<0>(problem_size), get<0>(problem_size) / 2);
|
||||
std::normal_distribution<double> dist_kv(get<1>(problem_size), get<1>(problem_size) / 2);
|
||||
std::cout << "N: " << num_batches << ", Q: " << get<0>(problem_size) << ", KV: " << get<1>(problem_size) << std::endl;
|
||||
|
||||
auto generate_positive_int = [](auto& dist, auto& gen) {
|
||||
int result = 0;
|
||||
do {
|
||||
result = static_cast<int>(dist(gen));
|
||||
} while (result <= 0);
|
||||
return result;
|
||||
};
|
||||
|
||||
cumulative_seqlen_q = {0};
|
||||
cumulative_seqlen_kv = {0};
|
||||
|
||||
int total_seqlen_q = 0;
|
||||
int total_seqlen_kv = 0;
|
||||
int max_seqlen_q = 0;
|
||||
int max_seqlen_kv = 0;
|
||||
|
||||
for (int i = 0; i < num_batches; i++) {
|
||||
int seqlen_q = kVarlenSame ? get<0>(problem_size) : generate_positive_int(dist_q, rng);
|
||||
int seqlen_kv = kVarlenSame ? get<1>(problem_size) : generate_positive_int(dist_kv, rng);
|
||||
|
||||
total_seqlen_q += seqlen_q;
|
||||
total_seqlen_kv += seqlen_kv;
|
||||
|
||||
max_seqlen_q = std::max(max_seqlen_q, seqlen_q);
|
||||
max_seqlen_kv = std::max(max_seqlen_kv, seqlen_kv);
|
||||
|
||||
cumulative_seqlen_q.push_back(cumulative_seqlen_q.back() + seqlen_q);
|
||||
cumulative_seqlen_kv.push_back(cumulative_seqlen_kv.back() + seqlen_kv);
|
||||
}
|
||||
std::cout << "Q max: " << max_seqlen_q << " total: " << total_seqlen_q << " vs even " << num_batches * get<0>(problem_size) << std::endl;
|
||||
std::cout << "KV max: " << max_seqlen_kv << " total: " << total_seqlen_kv << " vs even " << num_batches * get<1>(problem_size) << std::endl;
|
||||
|
||||
ProblemShape problem_size_for_init = problem_size;
|
||||
get<3,1>(problem_size_for_init) = 1;
|
||||
get<0>(problem_size_for_init) = total_seqlen_q;
|
||||
get<1>(problem_size_for_init) = total_seqlen_kv;
|
||||
|
||||
ProblemShapeType problem_size_for_launch;
|
||||
|
||||
get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q};
|
||||
get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv};
|
||||
get<2>(problem_size_for_launch) = get<2>(problem_size);
|
||||
get<3>(problem_size_for_launch) = get<3>(problem_size);
|
||||
|
||||
return cute::make_tuple(problem_size_for_init, problem_size_for_launch);
|
||||
}
|
||||
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
|
||||
ProblemShapeType initialize(const Options& options) {
|
||||
int h_r = options.h / options.h_k;
|
||||
assert(options.h % options.h_k == 0);
|
||||
auto problem_shape_in = cute::make_tuple(options.q, options.k, options.d, cute::make_tuple(cute::make_tuple(h_r, options.h_k), options.b));
|
||||
|
||||
ProblemShapeType problem_shape;
|
||||
decltype(problem_shape_in) problem_size;
|
||||
|
||||
if constexpr (kIsVarlen) {
|
||||
auto [problem_shape_init, problem_shape_launch] = initialize_varlen(problem_shape_in);
|
||||
problem_shape = problem_shape_launch;
|
||||
problem_size = problem_shape_init;
|
||||
}
|
||||
else {
|
||||
problem_size = problem_shape_in;
|
||||
problem_shape = problem_shape_in;
|
||||
}
|
||||
|
||||
get<2>(problem_size) = cutlass::round_up(get<2>(problem_size), 8); // alignment
|
||||
|
||||
auto shape_QO = select<0,2,3>(problem_size);
|
||||
auto shape_KV = select<1,2,3>(problem_size);
|
||||
auto shape_LSE = select<0,3>(problem_size);
|
||||
|
||||
int SQ = size<0>(problem_size);
|
||||
int SK = size<1>(problem_size);
|
||||
int D = size<2>(problem_size);
|
||||
int H = size<3,0>(problem_size);
|
||||
int H_K = size<3,0,1>(problem_size);
|
||||
int H_Q = size<3,0,0>(problem_size);
|
||||
int B = size<3,1>(problem_size);
|
||||
|
||||
stride_Q = make_stride(H*D , _1{}, make_stride(make_stride(D, H_Q*D), H*D*SQ));
|
||||
stride_O = stride_Q;
|
||||
stride_K = make_stride(H_K*D , _1{}, make_stride(make_stride(_0{}, D), H_K*D*SK));
|
||||
stride_V = stride_K;
|
||||
stride_LSE = make_stride(_1{}, make_stride(make_stride(SQ, SQ*H_Q), SQ*H));
|
||||
|
||||
if (kIsVarlen) {
|
||||
get<2,1>(stride_Q) = 0;
|
||||
get<2,1>(stride_K) = 0;
|
||||
get<2,1>(stride_V) = 0;
|
||||
get<2,1>(stride_O) = 0;
|
||||
get<1,1>(stride_LSE) = 0;
|
||||
}
|
||||
|
||||
block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
|
||||
block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
|
||||
block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
|
||||
block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
|
||||
block_LSE.reset(size(shape_LSE));
|
||||
block_ref_O.reset(size(shape_QO));
|
||||
block_ref_LSE.reset(size(shape_LSE));
|
||||
|
||||
initialize_block(block_Q, seed + 2023, options.init_style_q);
|
||||
initialize_block(block_K, seed + 2022, options.init_style_k);
|
||||
initialize_block(block_V, seed + 2021, options.init_style_v);
|
||||
|
||||
if ( ! cumulative_seqlen_q.empty()) {
|
||||
device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
|
||||
device_cumulative_seqlen_q.copy_from_host(
|
||||
cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
|
||||
}
|
||||
if ( ! cumulative_seqlen_kv.empty()) {
|
||||
device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
|
||||
device_cumulative_seqlen_kv.copy_from_host(
|
||||
cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
|
||||
}
|
||||
|
||||
if constexpr (kIsVarlen) {
|
||||
get<0>(problem_shape).cumulative_length = device_cumulative_seqlen_q.get();
|
||||
get<1>(problem_shape).cumulative_length = device_cumulative_seqlen_kv.get();
|
||||
}
|
||||
|
||||
return problem_shape;
|
||||
}
|
||||
|
||||
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
|
||||
|
||||
ProblemShapeType problem_shape = initialize(options);
|
||||
|
||||
typename Operation::Arguments arguments{
|
||||
problem_shape,
|
||||
{ block_Q.get(), stride_Q,
|
||||
block_K.get(), stride_K,
|
||||
block_V.get(), stride_V },
|
||||
{ block_O.get(), stride_O,
|
||||
block_LSE.get(), stride_LSE },
|
||||
hw_info
|
||||
};
|
||||
|
||||
Operation op;
|
||||
|
||||
ExampleResult example_result;
|
||||
|
||||
example_result.smem_size = Operation::Kernel::SharedStorageSize;
|
||||
|
||||
size_t workspace_size = 0;
|
||||
workspace_size = Operation::get_workspace_size(arguments);
|
||||
DeviceAllocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
cutlass::Status status = cutlass::Status::kSuccess;
|
||||
status = op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "This kernel is not supported. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
status = op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Run
|
||||
status = op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
//
|
||||
// Construct events
|
||||
//
|
||||
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (auto & event : events) {
|
||||
result = cudaEventCreate(&event);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
}
|
||||
|
||||
// Record an event at the start of a series of GEMMs
|
||||
result = cudaEventRecord(events[0]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
for (int i = 0; i < options.iterations; i++) {
|
||||
status = op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Stop profiling loop
|
||||
//
|
||||
|
||||
// Record an event when the GEMMs are complete
|
||||
result = cudaEventRecord(events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Wait for work on the device to complete.
|
||||
result = cudaEventSynchronize(events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
result = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
runtime_ms /= static_cast<float>(options.iterations);
|
||||
|
||||
double flops;
|
||||
if (kIsVarlen) {
|
||||
flops = 0.0;
|
||||
for (int i = 0; i < size<3,1>(problem_shape); i++) {
|
||||
flops += (cumulative_seqlen_q[i+1] - cumulative_seqlen_q[i])
|
||||
* 1.0
|
||||
* (cumulative_seqlen_kv[i+1] - cumulative_seqlen_kv[i]);
|
||||
}
|
||||
}
|
||||
else {
|
||||
flops = 1.0;
|
||||
flops *= static_cast<double>(size<0>(problem_shape));
|
||||
flops *= static_cast<double>(size<1>(problem_shape));
|
||||
flops *= static_cast<double>(size<3,1>(problem_shape));
|
||||
}
|
||||
flops *= 4.0 * (std::is_same_v<ActiveMask, CausalMask> ? 0.5 : 1.0);
|
||||
flops *= static_cast<double>(size<2>(problem_shape));
|
||||
flops *= static_cast<double>(size<3,0>(problem_shape));
|
||||
double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
|
||||
example_result.tflops_tc_s = tflops_s;
|
||||
example_result.runtime_ms = runtime_ms;
|
||||
|
||||
result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Verify that the result is correct
|
||||
bool passed = true;
|
||||
if (options.verify) {
|
||||
passed = verify(problem_shape);
|
||||
if (passed) example_result.verified = true;
|
||||
}
|
||||
|
||||
if (!passed) {
|
||||
std::cerr << "Reference check failed" << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
example_result.passed = true;
|
||||
|
||||
return example_result;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to print a description of the example run and its result
|
||||
void print_result(const std::string& description, ExampleResult result, bool verbose) {
|
||||
std::ios fmt(nullptr);
|
||||
fmt.copyfmt(std::cout);
|
||||
std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ");
|
||||
std::cout << std::setw(32) << std::left << description;
|
||||
std::cout.copyfmt(fmt);
|
||||
std::cout << " : " << result.tflops_tc_s << " TFLOPS/s" << std::endl;
|
||||
if (verbose) {
|
||||
std::cout << " t=" << result.runtime_ms << "ms, "
|
||||
"smem=" << result.smem_size << "b" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Mask>
|
||||
void run_fwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||
auto run = [&](auto shape, const char* name, auto... kernel_options) {
|
||||
if ((! options.kernel_filter.empty()) && (! std::regex_search(name, std::basic_regex(options.kernel_filter)))) {
|
||||
return;
|
||||
}
|
||||
if (options.varlen) {
|
||||
FwdRunner<true, decltype(shape), void, Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
}
|
||||
else
|
||||
{
|
||||
FwdRunner<false, decltype(shape), void, Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
}
|
||||
};
|
||||
|
||||
using HeadDim = _128;
|
||||
|
||||
// Persistent Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
|
||||
// Individual Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Mask>
|
||||
void run_fwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||
auto run = [&](auto shape, const char* name, auto... kernel_options) {
|
||||
if ((! options.kernel_filter.empty()) && (! std::regex_search(name, std::basic_regex(options.kernel_filter)))) {
|
||||
return;
|
||||
}
|
||||
if (options.varlen) {
|
||||
FwdRunner<true, decltype(shape), void, Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
}
|
||||
else
|
||||
{
|
||||
FwdRunner<false, decltype(shape), void, Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
}
|
||||
};
|
||||
|
||||
using HeadDim = _64;
|
||||
|
||||
// Persistent Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
|
||||
// Individual Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
|
||||
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Mask>
|
||||
void run_fwd_32(Mask fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) {
|
||||
auto run = [&](auto shape, const char* name, auto... kernel_options) {
|
||||
if (options.varlen) {
|
||||
FwdRunner<true, decltype(shape), void, Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
}
|
||||
else {
|
||||
FwdRunner<false, decltype(shape), void, Mask, decltype(kernel_options)...> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
}
|
||||
};
|
||||
|
||||
using HeadDim = _32;
|
||||
|
||||
#ifdef FP8
|
||||
// Persistent Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
|
||||
// Individual Tile Scheduler
|
||||
run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
|
||||
#endif
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main_single(int argc, char const **args) {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || props.major != 10) {
|
||||
std::cout
|
||||
<< "This example requires a GPU of NVIDIA's Blackwell Architecture "
|
||||
<< "(compute capability major 10) and CUDA 12.8 or greater.\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.error) {
|
||||
std::cerr << "Aborting execution." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
//
|
||||
// Run examples
|
||||
//
|
||||
|
||||
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
|
||||
// information is used by the underlying kernel.
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
if (options.sm_count == 0) {
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
else {
|
||||
hw_info.sm_count = options.sm_count;
|
||||
}
|
||||
|
||||
std::cout << "###### B " << options.b << " H " << options.h << " H_K " << options.h_k << " Q " << options.q << " K " << options.k << " D " << options.d << " ";
|
||||
std::cout << "Forward" << " " << (options.causal ? "Causal" : (options.residual ? "Residual" : "None")) << " ";
|
||||
std::cout << "#SM " << hw_info.sm_count << std::endl;
|
||||
|
||||
auto with_mask = [&](auto fn) {
|
||||
if (options.causal) {
|
||||
fn(CausalMask{});
|
||||
}
|
||||
else if (options.residual) {
|
||||
fn(ResidualMask{});
|
||||
}
|
||||
else {
|
||||
fn(NoMask{});
|
||||
}
|
||||
};
|
||||
|
||||
with_mask([&](auto fusion) {
|
||||
if (options.d <= 32) {
|
||||
run_fwd_32(fusion, options, hw_info);
|
||||
}
|
||||
else if (options.d <= 64) {
|
||||
run_fwd_64(fusion, options, hw_info);
|
||||
}
|
||||
else if (options.d <= 128) {
|
||||
run_fwd_128(fusion, options, hw_info);
|
||||
}
|
||||
else {
|
||||
std::cout << "No kernel instantiated for d=" << options.d << std::endl;
|
||||
}
|
||||
});
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
std::vector<std::string> full_arguments(args, args + argc);
|
||||
|
||||
int result = 0;
|
||||
|
||||
bool recursed = false;
|
||||
for (size_t i = 1; i < full_arguments.size(); i++) {
|
||||
if (full_arguments[i].find(',') != std::string::npos) {
|
||||
auto arg = full_arguments[i];
|
||||
size_t eq_pos = arg.find('=');
|
||||
std::string prefix = eq_pos == std::string::npos ? "" : arg.substr(0, eq_pos+1);
|
||||
std::string rest = eq_pos == std::string::npos ? arg : arg.substr(eq_pos+1);
|
||||
for (;;) {
|
||||
size_t comma_pos = rest.find(',');
|
||||
std::string current = rest.substr(0, comma_pos);
|
||||
full_arguments[i] = prefix + current;
|
||||
std::vector<const char*> next_args;
|
||||
for (auto& elem : full_arguments) { next_args.push_back(elem.data()); }
|
||||
main(argc, next_args.data());
|
||||
if (comma_pos == std::string::npos) break;
|
||||
rest = rest.substr(comma_pos+1);
|
||||
}
|
||||
recursed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (! recursed) {
|
||||
main_single(argc, args);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
832
examples/77_blackwell_fmha/77_blackwell_fmha_gen.cu
Normal file
832
examples/77_blackwell_fmha/77_blackwell_fmha_gen.cu
Normal file
@ -0,0 +1,832 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Example implementation of fused multi-head attention for the NVIDIA Blackwell SM100
|
||||
architecture using CUTLASS 3.
|
||||
|
||||
MQA/GQA
|
||||
-------
|
||||
|
||||
The head dimension can be represented as a tuple, where the K/V strides in the
|
||||
first dimension is zero. This has the effect of MQA or GQA.
|
||||
* MHA is (head_size:head_stride).
|
||||
* MQA is (head_size:head_stride) in Q and (head_size:_0) in K and V.
|
||||
* GQA is (grouped_heads,heads_kv):(head_stride,grouped_heads*head_stride) in Q
|
||||
and (grouped_heads,heads_kv):(0,head_stride) in K and V
|
||||
|
||||
Example usage:
|
||||
$ ./examples/77_blackell_fmha/77_blackell_fmha_gen_fp8 \
|
||||
--b=2048 --h=2048 --d=2048 --k=2048
|
||||
*/
|
||||
|
||||
#define DSHOW(x) print(#x ": "); print(x); print("\n");
|
||||
#define DSHOWT(x) print(#x ": "); print_tensor(x); print("\n");
|
||||
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <regex>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "reference/fmha_fwd_gen_reference.hpp"
|
||||
#include "reference/reference_abs_error.hpp"
|
||||
|
||||
#include "device/fmha.hpp"
|
||||
#include "collective/fmha_fusion.hpp"
|
||||
#include "collective/sm100_fmha_gen_mainloop_warpspecialized.hpp"
|
||||
#include "collective/sm100_fmha_gen_epilogue_warpspecialized.hpp"
|
||||
#include "kernel/sm100_fmha_gen_kernel_warpspecialized.hpp"
|
||||
#include "kernel/fmha_tile_scheduler.hpp"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using namespace cute;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
enum class InitStyle {
|
||||
kZero, kOne, kLinearStride128, kLinearStride1, kRandom, kNone
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help = false;
|
||||
bool error = false;
|
||||
|
||||
int b = 1;
|
||||
int h = 1;
|
||||
int h_k = 1;
|
||||
int k = 512;
|
||||
int d = 128;
|
||||
int iterations = 3;
|
||||
bool verify = false;
|
||||
bool verbose = false;
|
||||
bool remap = false;
|
||||
bool varlen = false;
|
||||
bool cache_only = false;
|
||||
|
||||
int sm_count = 0;
|
||||
|
||||
std::string kernel_filter;
|
||||
bool clear_cache = false;
|
||||
|
||||
InitStyle init_style_q = InitStyle::kRandom;
|
||||
InitStyle init_style_cache_k = InitStyle::kRandom;
|
||||
InitStyle init_style_cache_v = InitStyle::kRandom;
|
||||
InitStyle init_style_new_k = InitStyle::kRandom;
|
||||
InitStyle init_style_new_v = InitStyle::kRandom;
|
||||
|
||||
static void get_init_style_argument(cutlass::CommandLine& cmd, const char* name, InitStyle& dst, InitStyle const& src) {
|
||||
std::string s;
|
||||
cmd.get_cmd_line_argument(name, s, s);
|
||||
if (s.empty()) {
|
||||
dst = src;
|
||||
}
|
||||
else {
|
||||
if (s == "r") {
|
||||
dst = InitStyle::kRandom;
|
||||
}
|
||||
else if (s == "0") {
|
||||
dst = InitStyle::kZero;
|
||||
}
|
||||
else if (s == "1") {
|
||||
dst = InitStyle::kOne;
|
||||
}
|
||||
else if (s == "d") {
|
||||
dst = InitStyle::kLinearStride1;
|
||||
}
|
||||
else if (s == "s") {
|
||||
dst = InitStyle::kLinearStride128;
|
||||
}
|
||||
else if (s == "n") {
|
||||
dst = InitStyle::kNone;
|
||||
}
|
||||
else {
|
||||
std::cout << "Error: " << s << " is not a valid input type.\n";
|
||||
std::exit(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
Options defaults;
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("d", d, defaults.d);
|
||||
cmd.get_cmd_line_argument("h", h, -1);
|
||||
if (h == -1) h = 2048 / d;
|
||||
|
||||
cmd.get_cmd_line_argument("h_k", h_k, -1);
|
||||
if (h_k == -1) h_k = h;
|
||||
|
||||
cmd.get_cmd_line_argument("k", k, defaults.k);
|
||||
|
||||
cmd.get_cmd_line_argument("b", b, -1);
|
||||
if (b == -1) b = 16384 / k;
|
||||
if (b == 0) b = 1;
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
|
||||
verify = cmd.check_cmd_line_flag("verify");
|
||||
verbose = cmd.check_cmd_line_flag("verbose");
|
||||
varlen = cmd.check_cmd_line_flag("varlen");
|
||||
remap = cmd.check_cmd_line_flag("remap");
|
||||
cache_only = cmd.check_cmd_line_flag("cache-only");
|
||||
cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);
|
||||
|
||||
get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q);
|
||||
get_init_style_argument(cmd, "init-style", init_style_cache_k, defaults.init_style_cache_k);
|
||||
get_init_style_argument(cmd, "init-style", init_style_cache_v, defaults.init_style_cache_v);
|
||||
get_init_style_argument(cmd, "init-style", init_style_new_k, defaults.init_style_new_k);
|
||||
get_init_style_argument(cmd, "init-style", init_style_new_v, defaults.init_style_new_v);
|
||||
get_init_style_argument(cmd, "init-style-q", init_style_q, init_style_q);
|
||||
get_init_style_argument(cmd, "init-style-cache-k", init_style_cache_k, init_style_cache_k);
|
||||
get_init_style_argument(cmd, "init-style-cache-v", init_style_cache_v, init_style_cache_v);
|
||||
get_init_style_argument(cmd, "init-style-new-k", init_style_new_k, init_style_new_k);
|
||||
get_init_style_argument(cmd, "init-style-new-v", init_style_new_v, init_style_new_v);
|
||||
|
||||
clear_cache = cmd.check_cmd_line_flag("clear-cache");
|
||||
|
||||
cmd.get_cmd_line_argument("kernel-filter", kernel_filter, defaults.kernel_filter);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "77_blackwell_fmha_gen\n\n"
|
||||
<< " This example showcases the use of CUTLASS's collective operation builders to easily construct\n"
|
||||
<< " fused multi-head attention forward-pass gen-phase kernels targeting NVIDIA's Blackwell architecture.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
<< " --b=<int> Sets the B extent\n"
|
||||
<< " --h=<int> Sets the H extent\n"
|
||||
<< " --h_k=<int> Sets the H_K/V extent (for GQA/MQA)\n"
|
||||
<< " --k=<int> Sets the K extent (sampled around this length)\n"
|
||||
<< " --d=<int> Sets the D extentn"
|
||||
<< " --iterations=<int> Benchmarking iterations\n"
|
||||
<< " --verify Verify results\n"
|
||||
<< " --verbose Print smem and execution time per kernel\n"
|
||||
<< " --remap Enables batch index remapping\n"
|
||||
<< " --cache-only Only use data from KV cache, no reading or inserting new entry\n"
|
||||
<< " --varlen Varies sequence length between cache entries\n"
|
||||
<< " --sm-count Sets SM count rather than querying it\n"
|
||||
<< " --clear-cache Clears the cache before benchmarking runs\n"
|
||||
<< " --kernel-filter=<filter> Sets regexp to match kernel against\n"
|
||||
<< "\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to initialize a block of device data
|
||||
template <class Element>
|
||||
void initialize_block(
|
||||
DeviceAllocation<Element>& block,
|
||||
uint64_t seed=2023, InitStyle init_style = InitStyle::kRandom) {
|
||||
|
||||
switch (init_style) {
|
||||
case InitStyle::kZero: {
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, (Element) 0, (Element) 0);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kOne: {
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
block.get(), block.size(), seed, (Element) 1, (Element) 1);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kRandom: {
|
||||
cutlass::reference::device::BlockFillRandomGaussian(
|
||||
block.get(), block.size(), seed, (Element) 0, (Element) 1);
|
||||
break;
|
||||
}
|
||||
case InitStyle::kLinearStride1: {
|
||||
std::vector<Element> data(block.size());
|
||||
for (size_t i = 0; i < block.size() / 128; i ++) {
|
||||
for (int j = 0; j < 128; j++) {
|
||||
data[j + 128*i] = static_cast<Element>((double) (j % 4));
|
||||
}
|
||||
}
|
||||
block.copy_from_host(data.data(), data.size());
|
||||
break;
|
||||
}
|
||||
case InitStyle::kLinearStride128: {
|
||||
std::vector<Element> data(block.size());
|
||||
for (size_t i = 0; i < block.size() / 128; i ++) {
|
||||
for (int j = 0; j < 128; j++) {
|
||||
data[j + 128*i] = static_cast<Element>((double) (i % 4));
|
||||
}
|
||||
}
|
||||
block.copy_from_host(data.data(), data.size());
|
||||
break;
|
||||
}
|
||||
case InitStyle::kNone: {
|
||||
break;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ExampleResult {
|
||||
bool supported = false;
|
||||
bool passed = false;
|
||||
bool verified = false;
|
||||
float runtime_ms = 0;
|
||||
double tflops_tc_s = 0;
|
||||
double tops_exp2_s = 0;
|
||||
double tbytes_s = 0;
|
||||
size_t smem_size = 0;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ClearCache {
|
||||
const int size = 1024 * 1024 * 1024 / 4;
|
||||
DeviceAllocation<float> data;
|
||||
bool active = false;
|
||||
|
||||
ClearCache() = default;
|
||||
|
||||
void set_active(bool the_active) {
|
||||
active = the_active;
|
||||
if (active) {
|
||||
data.reset(size);
|
||||
}
|
||||
else {
|
||||
data.reset(0);
|
||||
}
|
||||
}
|
||||
|
||||
void operator ()() {
|
||||
if (active) {
|
||||
initialize_block(data, 0x49314, InitStyle::kRandom);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
enum class KernelType {
|
||||
UMMA_P, UMMA_I
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<KernelType kKernelType, class TileShape, class ThreadShape>
|
||||
struct ExampleRunner {
|
||||
|
||||
using Element = cutlass::float_e5m2_t;
|
||||
using ElementAcc = float;
|
||||
using ElementOut = cutlass::half_t;
|
||||
|
||||
using ProblemShape = Shape<_1, int, int, Shape<Shape<int, int>, int>>;
|
||||
|
||||
using StrideQ = Stride<_0, _1, Stride<Stride<int, int>, int>>;
|
||||
using StrideNewK = Stride<_0, _1, Stride<Stride<_0, int>, int>>;
|
||||
using StrideCacheK = Stride<int, _1, Stride<Stride<_0, int>, int>>;
|
||||
using StrideNewV = StrideNewK;
|
||||
using StrideCacheV = StrideCacheK;
|
||||
using StrideO = StrideQ;
|
||||
|
||||
using Kernel =
|
||||
cutlass::fmha::kernel::Sm100FmhaGenKernelWarpspecialized<
|
||||
ProblemShape,
|
||||
cutlass::fmha::collective::Sm100FmhaGenMainloopWarpspecialized<
|
||||
Element, ElementAcc, ElementAcc, ElementOut,
|
||||
TileShape,
|
||||
StrideQ, StrideNewK, StrideNewV,
|
||||
StrideCacheK, StrideCacheV, StrideO
|
||||
>,
|
||||
cutlass::fmha::collective::Sm100FmhaGenEpilogueWarpspecialized<ElementOut, StrideO>,
|
||||
std::conditional_t<kKernelType == KernelType::UMMA_P,
|
||||
cutlass::fmha::kernel::PersistentTileScheduler,
|
||||
cutlass::fmha::kernel::IndividualTileScheduler
|
||||
>
|
||||
>;
|
||||
|
||||
using Operation = cutlass::fmha::device::FMHA<Kernel>;
|
||||
|
||||
StrideQ stride_q;
|
||||
StrideNewK stride_new_k;
|
||||
StrideNewV stride_new_v;
|
||||
StrideCacheK stride_cache_k;
|
||||
StrideCacheV stride_cache_v;
|
||||
StrideO stride_o;
|
||||
uint64_t seed = 0;
|
||||
|
||||
std::vector<int> seqlen_kv;
|
||||
|
||||
DeviceAllocation<int> block_seqlen_kv;
|
||||
DeviceAllocation<int> block_cache_batch_idx;
|
||||
DeviceAllocation<Element> block_q;
|
||||
DeviceAllocation<Element> block_new_k;
|
||||
DeviceAllocation<Element> block_new_v;
|
||||
DeviceAllocation<Element> block_cache_k;
|
||||
DeviceAllocation<Element> block_cache_v;
|
||||
DeviceAllocation<ElementOut> block_o;
|
||||
|
||||
DeviceAllocation<Element> block_ref_cache_k;
|
||||
DeviceAllocation<Element> block_ref_cache_v;
|
||||
DeviceAllocation<ElementOut> block_ref_o;
|
||||
|
||||
ClearCache clear_cache;
|
||||
|
||||
bool verify(const ProblemShape& problem_shape) {
|
||||
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(block_q.get()), select<0,2,3>(problem_shape), stride_q);
|
||||
Tensor mNewK = make_tensor(make_gmem_ptr(block_new_k.get()), select<0,2,3>(problem_shape), stride_new_k);
|
||||
Tensor mNewV = make_tensor(make_gmem_ptr(block_new_v.get()), select<0,2,3>(problem_shape), stride_new_v);
|
||||
Tensor mCacheK = make_tensor(make_gmem_ptr(block_ref_cache_k.get()), select<1,2,3>(problem_shape), stride_cache_k);
|
||||
Tensor mCacheV = make_tensor(make_gmem_ptr(block_ref_cache_v.get()), select<1,2,3>(problem_shape), stride_cache_v);
|
||||
Tensor mO = make_tensor(make_gmem_ptr(block_ref_o.get()), select<0,2,3>(problem_shape), stride_o);
|
||||
|
||||
fmha_fwd_gen_reference<ElementAcc>(
|
||||
problem_shape, block_seqlen_kv.get(), block_cache_batch_idx.get(),
|
||||
mQ, mNewK, mNewV, mCacheK, mCacheV, mO);
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Reference kernel failed. Last CUDA error: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
const double kMaxDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-2;
|
||||
const double kMeanDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-3;
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
double max_diff = 0;
|
||||
double mean_diff = 0;
|
||||
reference_abs_diff(block_o, block_ref_o, max_diff, mean_diff);
|
||||
bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if (! passed_O) {
|
||||
std::cerr << "failed O: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
|
||||
reference_abs_diff(block_cache_k, block_ref_cache_k, max_diff, mean_diff);
|
||||
bool passed_K = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if ( ! passed_K) {
|
||||
std::cerr << "failed Cache K: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
|
||||
reference_abs_diff(block_cache_v, block_ref_cache_v, max_diff, mean_diff);
|
||||
bool passed_V = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
|
||||
if ( ! passed_V) {
|
||||
std::cerr << "failed Cache V: max diff " << max_diff
|
||||
<< " mean " << mean_diff << std::endl;
|
||||
}
|
||||
|
||||
return passed_O && passed_K && passed_V;
|
||||
}
|
||||
|
||||
ProblemShape initialize(const Options& options) {
|
||||
|
||||
clear_cache.set_active(options.clear_cache);
|
||||
|
||||
std::vector<int> cache_batch_idx;
|
||||
|
||||
// set up stides and sizes
|
||||
if (options.remap) {
|
||||
for (int i = 0; i < options.b; i++) {
|
||||
cache_batch_idx.push_back(i);
|
||||
}
|
||||
std::mt19937 rng(0x202305291305ull);
|
||||
std::shuffle(cache_batch_idx.begin(), cache_batch_idx.end(), rng);
|
||||
}
|
||||
|
||||
seqlen_kv = std::vector<int>(options.b, options.k);
|
||||
if (options.varlen) {
|
||||
std::mt19937 rng(0x202305151552ull);
|
||||
std::normal_distribution<double> dist_kv(options.k, options.k / 2);
|
||||
|
||||
auto generate_positive_int = [](auto& dist, auto& gen) {
|
||||
int result = 0;
|
||||
do {
|
||||
result = static_cast<int>(dist(gen));
|
||||
} while (result <= 0);
|
||||
return result;
|
||||
};
|
||||
|
||||
for (int i = 0; i < options.b; i++) {
|
||||
seqlen_kv[i] = generate_positive_int(dist_kv, rng);
|
||||
}
|
||||
}
|
||||
|
||||
int max_seqlen_kv = 0;
|
||||
for (auto e : seqlen_kv) {
|
||||
// if (options.varlen) std::cout << "seqlen " << e << std::endl;
|
||||
max_seqlen_kv = std::max(e, max_seqlen_kv);
|
||||
}
|
||||
|
||||
ProblemShape result = make_shape(_1{}, max_seqlen_kv + 1, options.d, make_shape(make_shape(options.h / options.h_k, options.h_k), options.b));
|
||||
|
||||
stride_q = make_stride(_0{}, _1{}, make_stride(make_stride(options.d, options.d * size<3,0,0>(result)), options.d * size<3,0>(result)));
|
||||
stride_new_k = make_stride(_0{}, _1{}, make_stride(make_stride(_0{}, options.d), options.d * size<3,0,1>(result)));
|
||||
stride_cache_k = make_stride(options.d * size<3,0,1>(result), _1{}, make_stride(make_stride(_0{}, options.d), options.d * size<3,0,1>(result) * get<1>(result)));
|
||||
|
||||
stride_new_v = stride_new_k;
|
||||
stride_cache_v = stride_cache_k;
|
||||
stride_o = stride_q;
|
||||
|
||||
block_q.reset(options.b * get<2,1>(stride_q));
|
||||
if (! options.cache_only) {
|
||||
block_new_k.reset(options.b * get<2,1>(stride_new_k));
|
||||
block_new_v.reset(options.b * get<2,1>(stride_new_v));
|
||||
}
|
||||
block_cache_k.reset(options.b * get<2,1>(stride_cache_k));
|
||||
block_cache_v.reset(options.b * get<2,1>(stride_cache_v));
|
||||
block_o.reset(options.b * get<2,1>(stride_o));
|
||||
|
||||
block_ref_cache_k.reset(options.b * get<2,1>(stride_cache_k));
|
||||
block_ref_cache_v.reset(options.b * get<2,1>(stride_cache_v));
|
||||
block_ref_o.reset(options.b * get<2,1>(stride_o));
|
||||
|
||||
initialize_block(block_q, seed + 2023, options.init_style_q);
|
||||
if (! options.cache_only) {
|
||||
initialize_block(block_new_k, seed + 2022, options.init_style_new_k);
|
||||
initialize_block(block_new_v, seed + 2021, options.init_style_new_v);
|
||||
}
|
||||
|
||||
initialize_block(block_cache_k, seed + 2024 - 2025, options.init_style_cache_k);
|
||||
initialize_block(block_cache_v, seed + 2025, options.init_style_cache_v);
|
||||
|
||||
block_ref_cache_k.copy_from_device(block_cache_k.get(), block_cache_k.size());
|
||||
block_ref_cache_v.copy_from_device(block_cache_v.get(), block_cache_v.size());
|
||||
block_seqlen_kv.reset(seqlen_kv.size());
|
||||
block_seqlen_kv.copy_from_host(seqlen_kv.data(), seqlen_kv.size());
|
||||
|
||||
if (! cache_batch_idx.empty()) {
|
||||
block_cache_batch_idx.reset(cache_batch_idx.size());
|
||||
block_cache_batch_idx.copy_from_host(cache_batch_idx.data(), cache_batch_idx.size());
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
|
||||
auto problem_shape = initialize(options);
|
||||
|
||||
typename Operation::Arguments arguments{
|
||||
problem_shape,
|
||||
block_seqlen_kv.get(), block_cache_batch_idx.get(),
|
||||
block_q.get(), stride_q,
|
||||
block_new_k.get(), stride_new_k,
|
||||
block_new_v.get(), stride_new_v,
|
||||
block_cache_k.get(), stride_cache_k,
|
||||
block_cache_v.get(), stride_cache_v,
|
||||
block_o.get(), stride_o,
|
||||
hw_info
|
||||
};
|
||||
|
||||
Operation op;
|
||||
|
||||
ExampleResult example_result;
|
||||
|
||||
example_result.smem_size = Operation::Kernel::SharedStorageSize;
|
||||
|
||||
size_t workspace_size = 0;
|
||||
workspace_size = Operation::get_workspace_size(arguments);
|
||||
DeviceAllocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
cutlass::Status status = cutlass::Status::kSuccess;
|
||||
status = op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
// std::cerr << "This kernel is not supported. Last CUDA error is: "
|
||||
// << cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
example_result.supported = true;
|
||||
|
||||
status = op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Run
|
||||
status = op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
//
|
||||
// Construct events
|
||||
//
|
||||
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (auto & event : events) {
|
||||
result = cudaEventCreate(&event);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
}
|
||||
|
||||
float total_runtime_ms = 0;
|
||||
|
||||
for (int i = 0; i < options.iterations; i++) {
|
||||
|
||||
clear_cache();
|
||||
|
||||
// Record an event at the start of a series of GEMMs
|
||||
result = cudaEventRecord(events[0]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
status = op.run();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(cudaGetLastError()) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Record an event when the GEMMs are complete
|
||||
result = cudaEventRecord(events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
//
|
||||
// Stop profiling loop
|
||||
//
|
||||
|
||||
// Wait for work on the device to complete.
|
||||
result = cudaEventSynchronize(events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
result = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaDeviceSynchronize() failed: " << cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
total_runtime_ms += runtime_ms;
|
||||
|
||||
}
|
||||
|
||||
float runtime_ms = total_runtime_ms / static_cast<float>(options.iterations);
|
||||
|
||||
double bytes;
|
||||
bytes = 0.0;
|
||||
bytes += double(sizeof(Element) * size<3>(problem_shape)); // Q
|
||||
bytes += double(sizeof(ElementOut) * size<3>(problem_shape)); // O
|
||||
bytes += 2.0 * double(sizeof(Element) * size<3>(problem_shape) / size<3,0,0>(problem_shape)); // NewK, NewV
|
||||
double total_seqlen_kv = 0;
|
||||
for (auto e : seqlen_kv) {
|
||||
total_seqlen_kv += double(e + 1);
|
||||
}
|
||||
bytes += 2.0 * double(sizeof(Element) * size<3,0,1>(problem_shape) * total_seqlen_kv); // CacheK, CacheV
|
||||
bytes *= static_cast<double>(size<2>(problem_shape));
|
||||
double tbytes_s = bytes * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);
|
||||
example_result.tbytes_s = tbytes_s;
|
||||
example_result.runtime_ms = runtime_ms;
|
||||
|
||||
result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
// Verify that the result is correct
|
||||
bool passed = true;
|
||||
if (options.verify) {
|
||||
passed = verify(problem_shape);
|
||||
if (passed) example_result.verified = true;
|
||||
}
|
||||
|
||||
if (!passed) {
|
||||
std::cerr << "Reference check failed" << std::endl;
|
||||
return example_result;
|
||||
}
|
||||
|
||||
example_result.passed = true;
|
||||
|
||||
return example_result;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to print a description of the example run and its result
|
||||
void print_result(const std::string& description, ExampleResult result, bool verbose) {
|
||||
std::ios fmt(nullptr);
|
||||
fmt.copyfmt(std::cout);
|
||||
std::cout << (result.supported ? (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] ") : "[NSUP] ");
|
||||
std::cout << std::setw(32) << std::left << description;
|
||||
std::cout.copyfmt(fmt);
|
||||
std::cout << " : " << result.tbytes_s << " TB/s" << std::endl;
|
||||
if (verbose) {
|
||||
std::cout << " t=" << result.runtime_ms << "ms, "
|
||||
"smem=" << result.smem_size << "b" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main_single(int argc, char const **args) {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (__CUDACC_VER_MAJOR__ < 12 || props.major < 10) {
|
||||
std::cout
|
||||
<< "This example requires a GPU of NVIDIA's Blackwell Architecture or "
|
||||
<< "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n";
|
||||
return 0;
|
||||
}
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.error) {
|
||||
std::cerr << "Aborting execution." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
//
|
||||
// Run examples
|
||||
//
|
||||
|
||||
// The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This
|
||||
// information is used by the underlying kernel.
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
// to use a GPU other than that with device ID 0.
|
||||
hw_info.device_id = 0;
|
||||
if (options.sm_count == 0) {
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
else {
|
||||
hw_info.sm_count = options.sm_count;
|
||||
}
|
||||
|
||||
std::cout << "###### B " << options.b << " H " << options.h << " H_K " << options.h_k << " K " << options.k << " D " << options.d << " ";
|
||||
std::cout << "Gen" << " " << (options.varlen ? "Variable" : "Uniform") << " " << (options.remap ? "Remap" : "Linear") << " ";
|
||||
std::cout << "#SM " << hw_info.sm_count << std::endl;
|
||||
|
||||
using UMMA = true_type;
|
||||
using FFMA2 = false_type;
|
||||
auto run = [&](const char* name, auto kernel_type, auto tile, auto thr) {
|
||||
if ((! options.kernel_filter.empty()) && (! std::regex_search(name, std::basic_regex(options.kernel_filter)))) {
|
||||
return;
|
||||
}
|
||||
ExampleRunner<decltype(kernel_type)::value, decltype(tile), decltype(thr)> runner;
|
||||
auto result = runner.run(options, hw_info);
|
||||
print_result(name, result, options.verbose);
|
||||
};
|
||||
|
||||
|
||||
#define RUN(MODE, m, n, k, tm, tn, tk) \
|
||||
run( \
|
||||
#MODE " " #m "x" #n "x" #k " / " #tm "x" #tn "x" #tk, \
|
||||
std::integral_constant<KernelType, KernelType::MODE>{}, Shape<_##m, _##n, _##k>{}, Shape<_##tm, _##tn, _##tk>{} \
|
||||
)
|
||||
|
||||
RUN(UMMA_I, 128, 64, 128, 1, 1, 1);
|
||||
RUN(UMMA_I, 128, 128, 128, 1, 1, 1);
|
||||
RUN(UMMA_I, 128, 256, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 64, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 128, 128, 1, 1, 1);
|
||||
RUN(UMMA_P, 128, 256, 128, 1, 1, 1);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
std::vector<std::string> full_arguments(args, args + argc);
|
||||
|
||||
int result = 0;
|
||||
|
||||
bool recursed = false;
|
||||
for (size_t i = 1; i < full_arguments.size(); i++) {
|
||||
if (full_arguments[i].find(',') != std::string::npos) {
|
||||
auto arg = full_arguments[i];
|
||||
size_t eq_pos = arg.find('=');
|
||||
std::string prefix = eq_pos == std::string::npos ? "" : arg.substr(0, eq_pos+1);
|
||||
std::string rest = eq_pos == std::string::npos ? arg : arg.substr(eq_pos+1);
|
||||
for (;;) {
|
||||
size_t comma_pos = rest.find(',');
|
||||
std::string current = rest.substr(0, comma_pos);
|
||||
full_arguments[i] = prefix + current;
|
||||
std::vector<const char*> next_args;
|
||||
for (auto& elem : full_arguments) { next_args.push_back(elem.data()); }
|
||||
main(argc, next_args.data());
|
||||
if (comma_pos == std::string::npos) break;
|
||||
rest = rest.substr(comma_pos+1);
|
||||
}
|
||||
recursed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (! recursed) {
|
||||
main_single(argc, args);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
105
examples/77_blackwell_fmha/CMakeLists.txt
Normal file
105
examples/77_blackwell_fmha/CMakeLists.txt
Normal file
@ -0,0 +1,105 @@
|
||||
# Copyright (c) 2014 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
set_property(
|
||||
SOURCE 77_blackwell_fmha.cu
|
||||
PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0 --ptxas-options -v")
|
||||
|
||||
set_property(
|
||||
SOURCE 77_blackwell_fmha_gen.cu
|
||||
PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0 --ptxas-options -v")
|
||||
|
||||
set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no)
|
||||
set(TEST_CAUSAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal)
|
||||
set(TEST_VARLEN --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=residual --varlen)
|
||||
set(TEST_HDIM64 --b=2 --h=4 --q=512 --k=512 --d=64 --verify)
|
||||
set(TEST_GQA --b=2 --h=4 --h_k=2 --q=512 --k=512 --d=64 --verify)
|
||||
|
||||
set(TEST_GEN_BASIC --b=1 --h=4 --k=512 --d=128 --verify)
|
||||
set(TEST_GEN_VARLEN --b=1 --h=4 --k=512 --d=128 --verify --varlen)
|
||||
set(TEST_GEN_HDIM64 --b=2 --h=4 --k=512 --d=64 --verify)
|
||||
set(TEST_GEN_GQA --b=2 --h=4 --h_k=2 --k=512 --d=64 --verify)
|
||||
set(TEST_GEN_REMAP --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --remap)
|
||||
set(TEST_GEN_CACHEONLY --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --cache-only)
|
||||
|
||||
if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")))
|
||||
if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_fp8
|
||||
77_blackwell_fmha.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_BASIC
|
||||
# TEST_CAUSAL
|
||||
# TEST_VARLEN
|
||||
# TEST_HDIM64
|
||||
# TEST_GQA)
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_fp8 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_fp8 PRIVATE FP8)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_gen_fp8
|
||||
77_blackwell_fmha_gen.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_GEN_BASIC
|
||||
# TEST_GEN_VARLEN
|
||||
# TEST_GEN_HDIM64
|
||||
# TEST_GEN_GQA
|
||||
# TEST_GEN_REMAP
|
||||
# TEST_GEN_CACHEONLY)
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_gen_fp8 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_compile_definitions(77_blackwell_fmha_gen_fp8 PRIVATE FP8)
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_fp16
|
||||
77_blackwell_fmha.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_BASIC
|
||||
# TEST_CAUSAL
|
||||
# TEST_VARLEN
|
||||
# TEST_HDIM64
|
||||
# TEST_GQA)
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_fp16 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
cutlass_example_add_executable(
|
||||
77_blackwell_fmha_gen_fp16
|
||||
77_blackwell_fmha_gen.cu
|
||||
TEST_COMMAND_OPTIONS
|
||||
TEST_GEN_BASIC
|
||||
# TEST_GEN_VARLEN
|
||||
# TEST_GEN_HDIM64
|
||||
# TEST_GEN_GQA
|
||||
# TEST_GEN_REMAP
|
||||
# TEST_GEN_CACHEONLY)
|
||||
)
|
||||
target_include_directories(77_blackwell_fmha_gen_fp16 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
endif()
|
||||
endif()
|
||||
23
examples/77_blackwell_fmha/README.md
Normal file
23
examples/77_blackwell_fmha/README.md
Normal file
@ -0,0 +1,23 @@
|
||||
# FMHA for Blackwell: Forward
|
||||
|
||||
This sample provides code for fused multi-head attention forward, context, or generation phase.
|
||||
It supports HeadDims of 32, 64, and 128, and fp8, fp16, and bf16 input data types.
|
||||
|
||||
For forward or context usage, use an M-blocking (Seqlen-Q) of 256 and an N-blocking (Seqlen-K) of 128.
|
||||
For generation usage, use an M-blocking (Num-Groups) of 128 (although the limit is currently 32 for actual Num-Groups), and a N-blocking (Seqlen-K) of 64, 128 or 256.
|
||||
|
||||
Context loads are done via TMA, whereas generation usage utilized `cp.async` and is thus more amenable to complex load patterns.
|
||||
|
||||
For variable sequence lenght, the code requires a batch of valid (but never used) padding memory ahead of the first input batch. This is achieved with least overhead by leaving one batch free and then arranging QKV consecutively.
|
||||
|
||||
The approach of this implementation is to reuse the selection logic of the collective gemm builder and recombine the result into an FMHA kernel.
|
||||
The kernel and collective layer are then formulated to be fmha-specific.
|
||||
The design assigns two tiles to each threadblock, and pingpongs between them in terms of matrix-matrix multiplication and softmax.
|
||||
|
||||
The example builds four binaries, showcasing the context and generation usage for fp8 and fp16.
|
||||
For detailed information on how to invoke them, check out either the tests in `CMakeLists.txt` or the `--help` for them.
|
||||
|
||||
To modify the code for fusions, `collective/fmha_fusion.hpp` provides the easiest customization point.
|
||||
The `apply_mask` function is called with the accumulator of the first GEMM and the logical positions of those elements.
|
||||
It is well-suited for applying masks or activations.
|
||||
More complex fusions that require memory loads would require modifying the mainloop collective to orchestrate the load via TMA.
|
||||
127
examples/77_blackwell_fmha/collective/fmha_common.hpp
Normal file
127
examples/77_blackwell_fmha/collective/fmha_common.hpp
Normal file
@ -0,0 +1,127 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
#include "cutlass/arch/reg_reconfig.h"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
namespace cutlass::fmha::collective {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<typename Atom, typename TA, typename TB, typename TC>
|
||||
CUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
|
||||
constexpr int rA = decltype(rank(tA))::value;
|
||||
constexpr int rB = decltype(rank(tB))::value;
|
||||
constexpr int rC = decltype(rank(tC))::value;
|
||||
static_assert(rA == 3 && rB == 3 && rC == 3);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tA); k_block++) {
|
||||
cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC);
|
||||
atom.accumulate_ = decltype(atom.accumulate_)::One;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Atom, typename TA, typename TB, typename TC>
|
||||
CUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
|
||||
atom.accumulate_ = decltype(atom.accumulate_)::Zero;
|
||||
gemm_reset_zero_acc(atom, tA, tB, tC);
|
||||
}
|
||||
|
||||
template<class Layout, class Stages = _1>
|
||||
CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) {
|
||||
return composition(layout, prepend<decltype(rank(layout))::value>(make_layout(stages), _));
|
||||
}
|
||||
|
||||
template<class T>
|
||||
CUTE_DEVICE T warp_uniform(T a) {
|
||||
return __shfl_sync(0xffffffff, a, 0);
|
||||
}
|
||||
|
||||
template <class a_type, class b_type, class c_type,
|
||||
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
|
||||
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg, class... TAs, class... TMs>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
to_tiled_mma_sm100_ts(
|
||||
TiledMMA<MMA_Atom<
|
||||
MMA_Traits<SM100_MMA_F8F6F4_SS, a_type, b_type, c_type,
|
||||
cute::C<M>, cute::C<N>,
|
||||
cute::integral_constant<UMMA::Major, a_major>,
|
||||
cute::integral_constant<UMMA::Major, b_major>,
|
||||
cute::integral_constant<UMMA::ScaleIn, a_neg>,
|
||||
cute::integral_constant<UMMA::ScaleIn, b_neg>>,
|
||||
TAs...>, TMs...>) {
|
||||
|
||||
return TiledMMA<MMA_Atom<
|
||||
MMA_Traits<SM100_MMA_F8F6F4_TS<a_type, b_type, c_type,
|
||||
M, N,
|
||||
a_major, b_major,
|
||||
a_neg, b_neg, UMMA::Saturate::False>>,
|
||||
TAs...>, TMs...>{};
|
||||
}
|
||||
|
||||
template <class a_type, class b_type, class c_type,
|
||||
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
|
||||
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg, class... TAs, class... TMs>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
to_tiled_mma_sm100_ts(
|
||||
TiledMMA<MMA_Atom<
|
||||
SM100_MMA_F16BF16_SS<a_type, b_type, c_type,
|
||||
M, N,
|
||||
a_major,
|
||||
b_major,
|
||||
a_neg,
|
||||
b_neg>,
|
||||
TAs...>, TMs...>) {
|
||||
return TiledMMA<MMA_Atom<
|
||||
SM100_MMA_F16BF16_TS<a_type, b_type, c_type,
|
||||
M, N,
|
||||
a_major, b_major,
|
||||
a_neg, b_neg, UMMA::Saturate::False>,
|
||||
TAs...>, TMs...>{};
|
||||
}
|
||||
|
||||
template<uint32_t RegCount>
|
||||
CUTLASS_DEVICE
|
||||
void warpgroup_reg_set() {
|
||||
if constexpr (RegCount < 128) {
|
||||
cutlass::arch::warpgroup_reg_dealloc<RegCount>();
|
||||
}
|
||||
else {
|
||||
cutlass::arch::warpgroup_reg_alloc<RegCount>();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
254
examples/77_blackwell_fmha/collective/fmha_fusion.hpp
Normal file
254
examples/77_blackwell_fmha/collective/fmha_fusion.hpp
Normal file
@ -0,0 +1,254 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
namespace cutlass::fmha::collective {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
struct NoMask {
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
return ceil_div(get<1>(problem_size), get<1>(tile_shape));
|
||||
}
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_masked_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_unmasked_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
return get_trip_count(blk_coord, tile_shape, problem_size);
|
||||
}
|
||||
|
||||
template<class AccQK, class IndexQK, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
void apply_mask(
|
||||
AccQK& acc_qk,
|
||||
IndexQK const& index_qk,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
struct ResidualMask : NoMask {
|
||||
|
||||
using Base = NoMask;
|
||||
|
||||
template <class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE int get_masked_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_unmasked_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
// if the sequence length does not divide the tile size evenly
|
||||
if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
|
||||
return get_trip_count(blk_coord, tile_shape, problem_size) - 1;
|
||||
}
|
||||
return get_trip_count(blk_coord, tile_shape, problem_size);
|
||||
}
|
||||
|
||||
template<class AccQK, class IndexQK, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
void apply_mask(
|
||||
AccQK& acc_qk,
|
||||
IndexQK const& index_qk,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
// This is useful is seqlen_k % kBlockN != 0 since it masks
|
||||
// the remaining elements out from softmax.
|
||||
// d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar
|
||||
// issues as they are transparently taken care of by TMA and the
|
||||
// epilogue, if it is instantiated with predication support.
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(acc_qk); i++) {
|
||||
auto pos = index_qk(i);
|
||||
if (get<1>(pos) >= get<1>(problem_size)) {
|
||||
acc_qk(i) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct CausalMask : NoMask {
|
||||
|
||||
using Base = NoMask;
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
// See note below on different ways to think about causal attention
|
||||
// Again, we'd add the offset_q into the max_blocks_q calculation
|
||||
int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
|
||||
int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));
|
||||
return std::min(max_blocks_k, max_blocks_q);
|
||||
}
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_masked_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
return ceil_div(get<0>(tile_shape), get<1>(tile_shape));
|
||||
}
|
||||
|
||||
template<class BlkCoord, class TileShape, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
int get_unmasked_trip_count(
|
||||
BlkCoord const& blk_coord,
|
||||
TileShape const& tile_shape,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size);
|
||||
}
|
||||
|
||||
template<class AccQK, class IndexQK, class ProblemSize>
|
||||
CUTLASS_DEVICE
|
||||
void apply_mask(
|
||||
AccQK& acc_qk,
|
||||
IndexQK const& index_qk,
|
||||
ProblemSize const& problem_size) {
|
||||
|
||||
// There are two ways to do causal if N_Q != N_K
|
||||
// (1) is to assume that the Q is at the beginning of the matrix
|
||||
// - this is what we demonstrate here
|
||||
// (2) is that it is at the end of the matrix
|
||||
// - this is usually what we want for inference settings
|
||||
// where we only compute the next row and use cache for the rest
|
||||
// - if you'd like this, you only need to add an offset like so:
|
||||
// get<0>(pos) + offset_q < get<1>(pos)
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(acc_qk); i++) {
|
||||
auto pos = index_qk(i);
|
||||
if ((get<0>(pos) < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) {
|
||||
acc_qk(i) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
struct VariableLength {
|
||||
int max_length;
|
||||
int* cumulative_length = nullptr;
|
||||
|
||||
CUTE_HOST_DEVICE operator int() const {
|
||||
return max_length;
|
||||
}
|
||||
};
|
||||
|
||||
template<class T> struct is_variable_length : std::false_type {};
|
||||
template<> struct is_variable_length<VariableLength> : std::true_type {};
|
||||
template<class T> constexpr bool is_variable_length_v = is_variable_length<T>::value;
|
||||
|
||||
template<class Shape, class Idx>
|
||||
CUTE_HOST_DEVICE
|
||||
constexpr auto
|
||||
apply_variable_length(Shape const& shape, Idx const& idx) {
|
||||
return transform_leaf(shape, [&](auto const& s) {
|
||||
if constexpr (is_variable_length_v<remove_cvref_t<decltype(s)>>) {
|
||||
return s.cumulative_length[idx+1] - s.cumulative_length[idx];
|
||||
}
|
||||
else {
|
||||
return s;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template<class Shape, class Coord, class Idx>
|
||||
CUTE_HOST_DEVICE
|
||||
constexpr auto
|
||||
apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) {
|
||||
auto new_shape = apply_variable_length(shape, idx);
|
||||
auto new_coord = transform_leaf(shape, coord, [&](auto const& s, auto const& c) {
|
||||
if constexpr (is_variable_length_v<remove_cvref_t<decltype(s)>>) {
|
||||
return cute::make_tuple(c, s.cumulative_length[idx]);
|
||||
}
|
||||
else {
|
||||
return c;
|
||||
}
|
||||
});
|
||||
return cute::make_tuple(new_shape, new_coord);
|
||||
}
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
|
||||
namespace cute {
|
||||
|
||||
template<>
|
||||
struct is_integral<cutlass::fmha::collective::VariableLength> : true_type {};
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
void print(cutlass::fmha::collective::VariableLength a) {
|
||||
printf("Varlen<%d, %p>", a.max_length, a.cumulative_length);
|
||||
}
|
||||
|
||||
}
|
||||
@ -0,0 +1,200 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/layout.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
namespace cutlass::fmha::collective {
|
||||
|
||||
template<
|
||||
class Element,
|
||||
class ElementAcc,
|
||||
class TileShape, // Q, D, _
|
||||
class StrideO, // Q, D, B
|
||||
class StrideLSE // Q, B
|
||||
>
|
||||
struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
|
||||
|
||||
using Pipeline = cutlass::PipelineAsync<2>;
|
||||
|
||||
// using SmemLayoutO = decltypa(make_layout(append<3>(select<0,1>(TileShape_WG{}), _2{})));
|
||||
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
|
||||
cute::UMMA::Major::K, Element, tuple_element_t<0, TileShape>, tuple_element_t<1, TileShape>>());
|
||||
// using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{}));
|
||||
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{}));
|
||||
using SmemLayoutO_ = SmemLayoutO;
|
||||
|
||||
struct TensorStorage {
|
||||
|
||||
using SmemLayoutO = SmemLayoutO_;
|
||||
cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>> smem_o;
|
||||
|
||||
};
|
||||
|
||||
struct Arguments {
|
||||
Element* ptr_O;
|
||||
StrideO dO;
|
||||
|
||||
ElementAcc* ptr_LSE;
|
||||
StrideLSE dLSE;
|
||||
};
|
||||
|
||||
using TMA_O = decltype(make_tma_copy(
|
||||
SM90_TMA_STORE{},
|
||||
make_tensor((Element*) nullptr, repeat_like(StrideO{}, 0), StrideO{}),
|
||||
SmemLayoutO{}(_,_,_0{})
|
||||
));
|
||||
|
||||
|
||||
struct Params {
|
||||
TMA_O tma_store_o;
|
||||
};
|
||||
|
||||
template<class ProblemShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape,
|
||||
Arguments const& args,
|
||||
void* workspace = nullptr) {
|
||||
|
||||
auto ptr_O = args.ptr_O;
|
||||
StrideO dO = args.dO;
|
||||
auto problem_shape_O = select<0,2,3>(problem_shape);
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
|
||||
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
|
||||
if (cumulative_length_q != nullptr) {
|
||||
int max_length_q = get<0>(problem_shape).max_length;
|
||||
// for variable sequence lenght, the batch is in units of row_stride
|
||||
get<2,1>(dO) = get<0>(dO);
|
||||
get<2,1>(problem_shape_O) = max_length_q * (1 + get<2,1>(problem_shape_O));
|
||||
// offset ptr by the amount we add back in later
|
||||
ptr_O -= max_length_q * get<0>(dO);
|
||||
}
|
||||
}
|
||||
|
||||
auto tma_store_o = make_tma_copy(
|
||||
SM90_TMA_STORE{},
|
||||
make_tensor(ptr_O, problem_shape_O, dO),
|
||||
SmemLayoutO{}(_,_,_0{})
|
||||
);
|
||||
|
||||
return {
|
||||
tma_store_o
|
||||
};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& params) {
|
||||
cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor());
|
||||
}
|
||||
|
||||
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
|
||||
CUTLASS_DEVICE auto
|
||||
store(
|
||||
BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,
|
||||
Params const& params, ParamsProblemShape const& params_problem_shape,
|
||||
TensorStorage& shared_storage,
|
||||
Pipeline& pipeline, typename Pipeline::PipelineState& pipeline_consumer_state) {
|
||||
|
||||
BlkCoord blk_coord = blk_coord_in;
|
||||
uint32_t lane_predicate = cute::elect_one_sync();
|
||||
|
||||
using X = Underscore;
|
||||
|
||||
int o0_index = 2 * get<0>(blk_coord);
|
||||
int o1_index = 2 * get<0>(blk_coord) + 1;
|
||||
|
||||
Tensor mO_qdl_p = params.tma_store_o.get_tma_tensor(select<0,2,3>(problem_shape));
|
||||
// offset mode 0 by (max_length - real_length)
|
||||
// offset mode 3,1 by cumulative_length + real_length
|
||||
// the ptr is already offset by - max_length
|
||||
// so in total this achieves
|
||||
int offs_0 = 0;
|
||||
int offs_2_1 = 0;
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
|
||||
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
|
||||
if (cumulative_length_q != nullptr) {
|
||||
int max_length_q = get<0>(params_problem_shape).max_length;
|
||||
offs_0 = max_length_q - get<0>(problem_shape);
|
||||
offs_2_1 = cumulative_length_q[get<2,1>(blk_coord)] + get<0>(problem_shape);
|
||||
get<2,1>(blk_coord) = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor mO_qdl = domain_offset(make_coord(offs_0, _0{}, make_coord(_0{}, offs_2_1)), mO_qdl_p);
|
||||
|
||||
Tensor gO_qdl = local_tile(mO_qdl, TileShape{}, make_coord(_, _, _), Step<_1, _1, X>{});
|
||||
Tensor gO = gO_qdl(_, _, _, _0{}, get<2>(blk_coord));
|
||||
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
|
||||
auto block_tma = params.tma_store_o.get_slice(0);
|
||||
Tensor tOsO = block_tma.partition_S(sO);
|
||||
Tensor tOgO = block_tma.partition_D(gO);
|
||||
|
||||
auto pipeline_release_state = pipeline_consumer_state;
|
||||
|
||||
// O1 O2
|
||||
// one pipeline: O
|
||||
// wait from corr, issue tma store on smem
|
||||
pipeline.consumer_wait(pipeline_consumer_state);
|
||||
++pipeline_consumer_state;
|
||||
|
||||
if (lane_predicate) {
|
||||
copy(params.tma_store_o, tOsO(_,_,_,_0{}), tOgO(_,_,_,o0_index));
|
||||
}
|
||||
tma_store_arrive();
|
||||
|
||||
pipeline.consumer_wait(pipeline_consumer_state);
|
||||
++pipeline_consumer_state;
|
||||
|
||||
if (lane_predicate) {
|
||||
copy(params.tma_store_o, tOsO(_,_,_,_1{}), tOgO(_,_,_,o1_index));
|
||||
}
|
||||
tma_store_arrive();
|
||||
|
||||
tma_store_wait<1>();
|
||||
|
||||
pipeline.consumer_release(pipeline_release_state);
|
||||
++pipeline_release_state;
|
||||
|
||||
tma_store_wait<0>();
|
||||
|
||||
pipeline.consumer_release(pipeline_release_state);
|
||||
++pipeline_release_state;
|
||||
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,94 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/layout.hpp"
|
||||
|
||||
namespace cutlass::fmha::collective {
|
||||
|
||||
template<
|
||||
class Element_,
|
||||
class StrideO_
|
||||
>
|
||||
struct Sm100FmhaGenEpilogueWarpspecialized {
|
||||
|
||||
using Pipeline = cutlass::PipelineAsync<2>;
|
||||
|
||||
using SmemLayoutO = Layout<Shape<_1, _1, _1>>;
|
||||
using SmemLayoutO_ = SmemLayoutO;
|
||||
using Element = Element_;
|
||||
using StrideOOrig = StrideO_;
|
||||
using StrideO = decltype(replace<0>(StrideOOrig{}, 0));
|
||||
|
||||
struct TensorStorage {
|
||||
|
||||
using SmemLayoutO = SmemLayoutO_;
|
||||
cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>> smem_o;
|
||||
|
||||
};
|
||||
|
||||
struct Arguments {
|
||||
Element* ptr_o;
|
||||
StrideO dO;
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
const Params& params;
|
||||
|
||||
CUTLASS_DEVICE Sm100FmhaGenEpilogueWarpspecialized(const Params& params) : params(params) {}
|
||||
|
||||
template<class ProblemShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape,
|
||||
Arguments const& args,
|
||||
void* workspace = nullptr) {
|
||||
return args;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& params) {
|
||||
/* no-op */
|
||||
}
|
||||
|
||||
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
|
||||
CUTLASS_DEVICE auto
|
||||
store(
|
||||
BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,
|
||||
Params const& params, ParamsProblemShape const& params_problem_shape,
|
||||
TensorStorage& shared_storage,
|
||||
Pipeline& pipeline, typename Pipeline::PipelineState& pipeline_consumer_state) {
|
||||
/* no-op */
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,395 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/memory_sm80.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/layout.hpp"
|
||||
|
||||
#include "collective/fmha_common.hpp"
|
||||
#include "collective/fmha_fusion.hpp"
|
||||
|
||||
namespace cutlass::fmha::collective {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<
|
||||
class Element,
|
||||
class StrideQ,
|
||||
class StrideNewK,
|
||||
class StrideNewV,
|
||||
class StrideCacheK,
|
||||
class StrideCacheV,
|
||||
class TensorStorage,
|
||||
class CollectiveMmaQK,
|
||||
class CollectiveMmaPV,
|
||||
class SmemLayoutQ,
|
||||
class SmemLayoutK,
|
||||
class SmemLayoutV,
|
||||
class PipelineQ,
|
||||
class PipelineKV,
|
||||
class TileShape,
|
||||
class Mask
|
||||
>
|
||||
struct Sm100FmhaLoadCpAsyncWarpspecialized {
|
||||
|
||||
using TileShapeQK = typename CollectiveMmaQK::TileShape;
|
||||
using TileShapePV = typename CollectiveMmaPV::TileShape;
|
||||
|
||||
struct Arguments {
|
||||
|
||||
const int* cache_batch_idx;
|
||||
|
||||
const Element* ptr_q;
|
||||
StrideQ dQ;
|
||||
|
||||
const Element* ptr_new_k;
|
||||
StrideNewK dNewK;
|
||||
const Element* ptr_new_v;
|
||||
StrideNewV dNewV;
|
||||
|
||||
Element* ptr_cache_k;
|
||||
StrideCacheK dCacheK;
|
||||
Element* ptr_cache_v;
|
||||
StrideCacheV dCacheV;
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template<class ProblemShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape,
|
||||
Arguments const& args,
|
||||
void* workspace) {
|
||||
|
||||
return args;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& params) {
|
||||
}
|
||||
|
||||
template<class TEngine, class TLayout>
|
||||
CUTLASS_DEVICE auto constexpr transpose(Tensor<TEngine, TLayout> const& t) {
|
||||
CUTE_STATIC_ASSERT_V(rank(t) == _2{});
|
||||
return t.compose(make_layout(make_shape(size<1>(t), size<0>(t)), make_stride(size<0>(t), _1{})));
|
||||
}
|
||||
|
||||
template<
|
||||
class CAtom, class TA, class TB,
|
||||
class CountTensor, class CountLimit,
|
||||
class SrcTensor, class DstTensor
|
||||
>
|
||||
CUTLASS_DEVICE void copy_with_limit(
|
||||
TiledCopy<CAtom, TA, TB> const& tiled_copy,
|
||||
CountTensor const& c, CountLimit const& l,
|
||||
SrcTensor const& src, DstTensor&& dst) {
|
||||
|
||||
//copy(tiled_copy, src, dst);
|
||||
#if 1
|
||||
auto c_f = make_tensor(c.data(), flatten(c.layout()));
|
||||
auto src_f = make_tensor(src.data(), flatten(src.layout()));
|
||||
auto dst_f = make_tensor(dst.data(), flatten(dst.layout()));
|
||||
auto c_v = group_modes<1,rank_v<decltype(c_f)>>(c_f);
|
||||
auto src_v = group_modes<1,rank_v<decltype(src_f)>>(src_f);
|
||||
auto dst_v = group_modes<1,rank_v<decltype(dst_f)>>(dst_f);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<1>(src_v); i++) {
|
||||
if (elem_less(c_v(_0{}, i), l)) {
|
||||
copy(CAtom{}, src_v(_, i), dst_v(_, i));
|
||||
}
|
||||
else {
|
||||
clear(dst_v(_, i));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
|
||||
CUTLASS_DEVICE void
|
||||
load(
|
||||
BlkCoord const& blk_coord, ProblemShape const& problem_shape,
|
||||
Params const& params, ParamsProblemShape const& params_problem_shape,
|
||||
TensorStorage& storage,
|
||||
PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state,
|
||||
PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) {
|
||||
|
||||
int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);
|
||||
mask_tile_count *= 2;
|
||||
|
||||
int warp_idx = (threadIdx.x / 32) % 2;
|
||||
int thread_idx = warp_idx * 32 + (threadIdx.x % 32);
|
||||
|
||||
using X = Underscore;
|
||||
|
||||
// this one is only executed by one thread, no need to elect_one
|
||||
auto blk_coord_cache = blk_coord;
|
||||
if (params.cache_batch_idx != nullptr) {
|
||||
get<2,1>(blk_coord_cache) = params.cache_batch_idx[get<2,1>(blk_coord_cache)];
|
||||
}
|
||||
|
||||
// Q1, K1, K2, V1, K3, V2, ... Kn, Vn-1, Vn
|
||||
// two pipes: Q and KV
|
||||
auto cQ = make_identity_tensor(select<0,2>(TileShape{}));
|
||||
auto mQ = make_tensor(make_gmem_ptr(params.ptr_q), append<3>(select<0,2>(TileShapeQK{}), get<3>(problem_shape)), params.dQ);
|
||||
auto gQ = mQ(_, _, get<2>(blk_coord));
|
||||
auto sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
|
||||
|
||||
typename CollectiveMmaQK::TiledMma mma_qk;
|
||||
ThrMMA thr_mma_qk = mma_qk.get_slice(0);
|
||||
auto tSgQ = thr_mma_qk.partition_A(gQ);
|
||||
auto tScQ = thr_mma_qk.partition_A(cQ);
|
||||
|
||||
auto atom_q_tv = Layout<Shape<Shape<_2, _32>, Shape<_16, _16>>, Stride<Stride<_16, _32>, Stride<_1, _1024>>>{};
|
||||
auto atom_kv_tv = Layout<Shape<Shape<_2, _32>, Shape<_16, _4>>, Stride<Stride<_16, _32>, Stride<_1, _1024>>>{};
|
||||
|
||||
auto tiled_copy_q = make_cotiled_copy(
|
||||
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, Element>{},
|
||||
atom_q_tv,
|
||||
make_layout(shape(tSgQ), replace<0>(stride(tSgQ), replace<0>(stride<0>(tSgQ), get<2>(TileShape{})))));
|
||||
|
||||
auto thr_copy_q = tiled_copy_q.get_slice(thread_idx);
|
||||
|
||||
auto tQsQ = thr_copy_q.partition_D(sQ);
|
||||
auto tQgQ = thr_copy_q.partition_S(tSgQ);
|
||||
auto tQcQ = thr_copy_q.partition_S(tScQ);
|
||||
|
||||
auto limitQ = append<2>(get<0>(problem_shape), _128{});
|
||||
|
||||
// Q1
|
||||
int q0_index = get<0>(blk_coord);
|
||||
// pipeline_q.producer_acquire(pipeline_q_producer_state);
|
||||
|
||||
// copy_with_limit(tiled_copy_q, tQcQ, limitQ, tQgQ, tQsQ(_, _, _, _, pipeline_q_producer_state.index());
|
||||
auto load_q = [&](int q_index, auto& state) {
|
||||
pipeline_q.producer_acquire(state);
|
||||
|
||||
// using Vec = Element;
|
||||
// auto vzero = Element(0);
|
||||
// q is always loaded masked
|
||||
using Vec = uint128_t;
|
||||
Vec vzero = uint128_t(0, 0);
|
||||
//auto src = recast<Vec>(tQgQ(_, _, _, _, q_index));
|
||||
auto src = recast<Vec>(tQgQ(_, _, _, _));
|
||||
auto dst = recast<Vec>(tQsQ(_, _, _, _, state.index()));
|
||||
// auto c = tQcQ(_, _, _, _, q_index);
|
||||
auto c = tQcQ(_, _, _, _);
|
||||
int vlen = sizeof(Vec) / sizeof(Element);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(src); i++) {
|
||||
auto cc = c(vlen*i);
|
||||
Vec* dst_ptr = &dst(i);
|
||||
const Vec* src_ptr = &src(i);
|
||||
bool guard = elem_less(cc, limitQ);
|
||||
cutlass::arch::cp_async_zfill<16, cutlass::arch::CacheOperation::Always>(
|
||||
dst_ptr, src_ptr, guard
|
||||
);
|
||||
}
|
||||
|
||||
pipeline_q.producer_commit(state, cutlass::arch::cpasync_barrier_arrive);
|
||||
};
|
||||
|
||||
load_q(q0_index, pipeline_q_producer_state);
|
||||
// pipeline_q.producer_commit(pipeline_q_producer_state, cutlass::arch::cpasync_barrier_arrive);
|
||||
++pipeline_q_producer_state;
|
||||
|
||||
auto cK_t = make_identity_tensor(select<1,2>(TileShapeQK{}));
|
||||
auto cK = make_tensor(cK_t.data(), make_layout(get<0>(cK_t.layout()), get<1>(cK_t.layout()), make_layout(_2{}, get<1>(TileShapeQK{}) * stride<0>(cK_t))));
|
||||
auto mK = make_tensor(make_gmem_ptr(params.ptr_cache_k), select<1,2,3>(problem_shape), params.dCacheK);
|
||||
auto gK = local_tile(mK(_, _, get<2>(blk_coord_cache)), TileShapeQK{}, make_coord(_, _, _0{}), Step<X, _1, _1>{});
|
||||
auto sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
|
||||
|
||||
auto tSgK = thr_mma_qk.partition_B(gK);
|
||||
auto tScK = thr_mma_qk.partition_B(cK);
|
||||
|
||||
auto tSlK = thr_mma_qk.partition_B(make_tensor((Element*) nullptr, make_ordered_layout(select<1,2>(TileShapeQK{}), Step<_1, _0>{})));
|
||||
auto tiled_copy_k = make_cotiled_copy(
|
||||
Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<uint128_t>, Element>{},
|
||||
atom_kv_tv,
|
||||
tSlK.layout());
|
||||
|
||||
auto thr_copy_k = tiled_copy_k.get_slice(thread_idx);
|
||||
|
||||
auto tKsK = thr_copy_k.partition_D(sK);
|
||||
auto tKgK = thr_copy_k.partition_S(tSgK);
|
||||
auto tKcK = thr_copy_k.partition_S(tScK);
|
||||
|
||||
int seqlen_cache_kv = get<1>(problem_shape) - ((params.ptr_new_k != nullptr) ? 1 : 0);
|
||||
auto limitK = append<2>(seqlen_cache_kv, _128{});
|
||||
|
||||
auto cV_t = make_identity_tensor(select<1,2>(TileShapePV{}));
|
||||
auto cV = make_tensor(cV_t.data(), make_layout(get<0>(cV_t.layout()), get<1>(cV_t.layout()), make_layout(_2{}, get<2>(TileShapePV{}) * stride<1>(cV_t))));
|
||||
auto mV = make_tensor(make_gmem_ptr(params.ptr_cache_v), select<2,1,3>(problem_shape), select<1,0,2>(params.dCacheV));
|
||||
auto gV = local_tile(mV(_, _, get<2>(blk_coord_cache)), TileShapePV{}, make_coord(_, _0{}, _), Step<X, _1, _1>{});
|
||||
auto sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
|
||||
|
||||
typename CollectiveMmaPV::TiledMma mma_pv;
|
||||
ThrMMA thr_mma_pv = mma_pv.get_slice(0);
|
||||
auto tOgV = thr_mma_pv.partition_B(gV);
|
||||
auto tOcV = thr_mma_pv.partition_B(cV);
|
||||
auto tOlV = thr_mma_pv.partition_B(make_tensor((Element*) nullptr, make_layout(select<1,2>(TileShapePV{}))));
|
||||
|
||||
auto tiled_copy_v = make_cotiled_copy(
|
||||
Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<uint128_t>, Element>{},
|
||||
atom_kv_tv,
|
||||
tOlV.layout());
|
||||
|
||||
auto thr_copy_v = tiled_copy_v.get_slice(thread_idx);
|
||||
|
||||
auto tVsV = thr_copy_v.partition_D(sV);
|
||||
auto tVgV = thr_copy_v.partition_S(tOgV);
|
||||
auto tVcV = thr_copy_v.partition_S(tOcV);
|
||||
|
||||
auto limitV = select<1,0>(limitK);
|
||||
|
||||
int full_tiles_cache = seqlen_cache_kv / get<1>(TileShapeQK{});
|
||||
|
||||
bool has_new = params.ptr_new_k != nullptr;
|
||||
Tensor mNewK = make_tensor(make_gmem_ptr(params.ptr_new_k), select<1,2,3>(problem_shape), params.dNewK);
|
||||
Tensor mNewV = make_tensor(make_gmem_ptr(params.ptr_new_v), select<1,2,3>(problem_shape), params.dNewV);
|
||||
Tensor gNewK = mNewK(_, _, get<2>(blk_coord));
|
||||
Tensor gNewV = mNewV(_, _, get<2>(blk_coord));
|
||||
|
||||
auto load_k = [&](int k_index, auto& state) {
|
||||
pipeline_kv.producer_acquire(state);
|
||||
|
||||
if (k_index < full_tiles_cache) {
|
||||
copy(tiled_copy_k, tKgK(_, _, _, _, k_index), tKsK(_, _, _, _, state.index()));
|
||||
pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive);
|
||||
} else {
|
||||
// using Vec = Element;
|
||||
// auto vzero = Element(0);
|
||||
using Vec = uint128_t;
|
||||
Vec vzero = uint128_t(0, 0);
|
||||
auto src = recast<Vec>(tKgK(_, _, _, _, k_index));
|
||||
auto dst = recast<Vec>(tKsK(_, _, _, _, state.index()));
|
||||
auto src2 = recast<Vec>(gNewK);
|
||||
auto c = tKcK(_, _, _, _, k_index);
|
||||
int vlen = sizeof(Vec) / sizeof(Element);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(src); i++) {
|
||||
auto cc = c(vlen*i);
|
||||
Vec* dst_ptr = &dst(i);
|
||||
const Vec* src_ptr = &src(i);
|
||||
bool guard = elem_less(cc, limitK);
|
||||
if (get<0>(cc) == seqlen_cache_kv && has_new) {
|
||||
src_ptr = &src2(_0{}, get<1>(cc) / vlen);
|
||||
guard = true;
|
||||
}
|
||||
cutlass::arch::cp_async_zfill<16, cutlass::arch::CacheOperation::Global>(
|
||||
dst_ptr, src_ptr, guard
|
||||
);
|
||||
}
|
||||
|
||||
pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive);
|
||||
}
|
||||
};
|
||||
|
||||
auto load_v = [&](int v_index, auto& state) {
|
||||
pipeline_kv.producer_acquire(state);
|
||||
|
||||
if (v_index < full_tiles_cache) {
|
||||
copy(tiled_copy_v, tVgV(_, _, _, _, v_index), tVsV(_, _, _, _, state.index()));
|
||||
pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive);
|
||||
} else {
|
||||
// using Vec = Element;
|
||||
// auto vzero = Element(0);
|
||||
using Vec = uint128_t;
|
||||
Vec vzero = uint128_t(0, 0);
|
||||
auto src = recast<Vec>(tVgV(_, _, _, _, v_index));
|
||||
auto dst = recast<Vec>(tVsV(_, _, _, _, state.index()));
|
||||
auto src2 = recast<Vec>(gNewV);
|
||||
int vlen = sizeof(Vec) / sizeof(Element);
|
||||
auto c = tVcV(_, _, _, _, v_index);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(src); i++) {
|
||||
auto cc = c(vlen*i);
|
||||
Vec* dst_ptr = &dst(i);
|
||||
const Vec* src_ptr = &src(i);
|
||||
bool guard = elem_less(cc, limitV);
|
||||
if (get<1>(cc) == seqlen_cache_kv && has_new) {
|
||||
src_ptr = &src2(_0{}, get<0>(cc) / vlen);
|
||||
guard = true;
|
||||
}
|
||||
cutlass::arch::cp_async_zfill<16, cutlass::arch::CacheOperation::Global>(
|
||||
dst_ptr, src_ptr, guard
|
||||
);
|
||||
}
|
||||
|
||||
pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive);
|
||||
}
|
||||
};
|
||||
|
||||
// K1
|
||||
int k_index = 0;
|
||||
int v_index = 0;
|
||||
|
||||
load_k(k_index, pipeline_kv_producer_state);
|
||||
|
||||
++pipeline_kv_producer_state;
|
||||
k_index += 1;
|
||||
|
||||
mask_tile_count -= 1;
|
||||
|
||||
for (; mask_tile_count > 0; mask_tile_count -= 1) {
|
||||
|
||||
load_k(k_index, pipeline_kv_producer_state);
|
||||
|
||||
++pipeline_kv_producer_state;
|
||||
k_index += 1;
|
||||
|
||||
load_v(v_index, pipeline_kv_producer_state);
|
||||
|
||||
++pipeline_kv_producer_state;
|
||||
v_index += 1;
|
||||
}
|
||||
|
||||
// V1
|
||||
|
||||
load_v(v_index, pipeline_kv_producer_state);
|
||||
|
||||
++pipeline_kv_producer_state;
|
||||
v_index += 1;
|
||||
|
||||
if (has_new) {
|
||||
for (int i = thread_idx; i < get<2>(TileShape{}); i += 64) {
|
||||
gK(seqlen_cache_kv, i, 0) = gNewK(0, i);
|
||||
gV(i, seqlen_cache_kv, 0) = gNewV(0, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
@ -0,0 +1,316 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/memory_sm80.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/layout.hpp"
|
||||
|
||||
#include "collective/fmha_common.hpp"
|
||||
#include "collective/fmha_fusion.hpp"
|
||||
|
||||
namespace cutlass::fmha::collective {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<
|
||||
class Element,
|
||||
class StrideQ,
|
||||
class StrideK,
|
||||
class StrideV,
|
||||
class CollectiveMmaQK,
|
||||
class CollectiveMmaPV,
|
||||
class SmemLayoutQ,
|
||||
class SmemLayoutK,
|
||||
class SmemLayoutV,
|
||||
class TensorStorage,
|
||||
class PipelineQ,
|
||||
class PipelineKV,
|
||||
class Mask,
|
||||
class TileShape
|
||||
>
|
||||
struct Sm100FmhaLoadTmaWarpspecialized {
|
||||
|
||||
using TileShapeQK = typename CollectiveMmaQK::TileShape;
|
||||
using TileShapePV = typename CollectiveMmaPV::TileShape;
|
||||
|
||||
struct Arguments {
|
||||
const Element* ptr_Q;
|
||||
StrideQ dQ;
|
||||
const Element* ptr_K;
|
||||
StrideK dK;
|
||||
const Element* ptr_V;
|
||||
StrideV dV;
|
||||
};
|
||||
|
||||
using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
|
||||
using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
|
||||
using TMA_V = typename CollectiveMmaPV::Params::TMA_B;
|
||||
|
||||
struct Params {
|
||||
TMA_Q tma_load_q;
|
||||
TMA_K tma_load_k;
|
||||
TMA_V tma_load_v;
|
||||
};
|
||||
|
||||
template<class ProblemShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape,
|
||||
Arguments const& args,
|
||||
void* workspace) {
|
||||
|
||||
auto ptr_Q = args.ptr_Q;
|
||||
auto ptr_K = args.ptr_K;
|
||||
auto ptr_V = args.ptr_V;
|
||||
auto dQ = args.dQ;
|
||||
auto dK = args.dK;
|
||||
auto dV = args.dV;
|
||||
auto problem_shape_qk = problem_shape;
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
|
||||
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
|
||||
if (cumulative_length_q != nullptr) {
|
||||
int max_length_q = get<0>(problem_shape).max_length;
|
||||
// for variable sequence lenght, the batch is in units of row_stride
|
||||
get<2,1>(dQ) = get<0>(dQ);
|
||||
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape)));
|
||||
// offset ptr by the amount we add back in later
|
||||
ptr_Q -= max_length_q * get<0>(dQ);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<1, ProblemShape>>) {
|
||||
auto cumulative_length_kv = get<1>(problem_shape).cumulative_length;
|
||||
if (cumulative_length_kv != nullptr) {
|
||||
int max_length_kv = get<1>(problem_shape).max_length;
|
||||
// for variable sequence lenght, the batch is in units of row_stride
|
||||
get<2,1>(dK) = get<0>(dK);
|
||||
get<2,1>(dV) = get<0>(dV);
|
||||
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape)));
|
||||
// offset ptr by the amount we add back in later
|
||||
ptr_K -= max_length_kv * get<0>(dK);
|
||||
ptr_V -= max_length_kv * get<0>(dV);
|
||||
}
|
||||
}
|
||||
|
||||
auto params_qk = CollectiveMmaQK::to_underlying_arguments(
|
||||
problem_shape_qk,
|
||||
typename CollectiveMmaQK::Arguments {
|
||||
ptr_Q, dQ,
|
||||
ptr_K, dK,
|
||||
}, /*workspace=*/ nullptr);
|
||||
|
||||
auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk);
|
||||
auto params_pv = CollectiveMmaPV::to_underlying_arguments(
|
||||
problem_shape_pv,
|
||||
typename CollectiveMmaPV::Arguments {
|
||||
ptr_K, dK, // never used, dummy
|
||||
ptr_V, select<1,0,2>(dV),
|
||||
}, /*workspace=*/ nullptr);
|
||||
|
||||
return Params{
|
||||
params_qk.tma_load_a,
|
||||
params_qk.tma_load_b,
|
||||
params_pv.tma_load_b
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& params) {
|
||||
cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
|
||||
}
|
||||
|
||||
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
|
||||
CUTLASS_DEVICE void
|
||||
load(
|
||||
BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,
|
||||
Params const& params, ParamsProblemShape const& params_problem_shape,
|
||||
TensorStorage& storage,
|
||||
PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state,
|
||||
PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) {
|
||||
|
||||
BlkCoord blk_coord_q = blk_coord_in;
|
||||
BlkCoord blk_coord_kv = blk_coord_in;
|
||||
|
||||
int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape);
|
||||
|
||||
using X = Underscore;
|
||||
|
||||
// this one is only executed by one thread, no need to elect_one
|
||||
|
||||
// Q1, K1, Q2, V1, K2, V2, K3, V3, ...
|
||||
// two pipes: Q and KV
|
||||
// from Memory (prod) to TensorCore (cons)
|
||||
|
||||
// compute gQ, sQ
|
||||
// we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1
|
||||
ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0);
|
||||
Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape));
|
||||
|
||||
int q_offs_0 = 0;
|
||||
int q_offs_2_1 = 0;
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
|
||||
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
|
||||
if (cumulative_length_q != nullptr) {
|
||||
int max_length_q = get<0>(params_problem_shape).max_length;
|
||||
q_offs_0 = max_length_q - get<0>(problem_shape);
|
||||
q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape);
|
||||
get<2,1>(blk_coord_q) = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p);
|
||||
|
||||
Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
|
||||
Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);
|
||||
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
|
||||
auto [tQgQ_qdl, tQsQ] = tma_partition(
|
||||
params.tma_load_q, _0{}, make_layout(_1{}),
|
||||
group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl)
|
||||
);
|
||||
Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q));
|
||||
|
||||
// compute gK, sK
|
||||
Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape));
|
||||
|
||||
int kv_offs_0 = 0;
|
||||
int kv_offs_2_1 = 0;
|
||||
|
||||
if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) {
|
||||
auto cumulative_length = get<1>(params_problem_shape).cumulative_length;
|
||||
if (cumulative_length != nullptr) {
|
||||
int max_length = get<1>(params_problem_shape).max_length;
|
||||
kv_offs_0 = max_length - get<1>(problem_shape);
|
||||
kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape);
|
||||
get<2,1>(blk_coord_kv) = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p);
|
||||
|
||||
Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
|
||||
Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
|
||||
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
|
||||
auto [tKgK_kdl, tKsK] = tma_partition(
|
||||
params.tma_load_k, _0{}, make_layout(_1{}),
|
||||
group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl)
|
||||
);
|
||||
Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv));
|
||||
|
||||
// compute gV, sV
|
||||
ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
|
||||
Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape));
|
||||
|
||||
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p);
|
||||
|
||||
Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});
|
||||
Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);
|
||||
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
|
||||
auto [tVgV_dkl, tVsV] = tma_partition(
|
||||
params.tma_load_v, _0{}, make_layout(_1{}),
|
||||
group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl)
|
||||
);
|
||||
auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv));
|
||||
|
||||
// blk_coord in decomposed in terms of TileShape, not TileShapeQK
|
||||
// As such, it needs to be transformed as
|
||||
// (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1)
|
||||
// b -> 2*a (Ki i even) 2*a+1 (Ki i odd)
|
||||
|
||||
uint32_t lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Q1
|
||||
int q0_index = 2 * get<0>(blk_coord_q);
|
||||
int q1_index = 2 * get<0>(blk_coord_q) + 1;
|
||||
pipeline_q.producer_acquire(pipeline_q_producer_state);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
|
||||
copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index()));
|
||||
}
|
||||
++pipeline_q_producer_state;
|
||||
|
||||
// K1
|
||||
int k_index = 0;
|
||||
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
|
||||
copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index()));
|
||||
}
|
||||
++pipeline_kv_producer_state;
|
||||
|
||||
// Q2
|
||||
pipeline_q.producer_acquire(pipeline_q_producer_state);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
|
||||
copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index()));
|
||||
}
|
||||
++pipeline_q_producer_state;
|
||||
|
||||
// V1
|
||||
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
|
||||
copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index()));
|
||||
}
|
||||
++pipeline_kv_producer_state;
|
||||
k_index += 1;
|
||||
|
||||
// loop:
|
||||
mask_tile_count -= 1;
|
||||
for (; mask_tile_count > 0; mask_tile_count -= 1) {
|
||||
|
||||
// Ki
|
||||
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
|
||||
copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index()));
|
||||
}
|
||||
++pipeline_kv_producer_state;
|
||||
|
||||
// Vi
|
||||
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
|
||||
if (lane_predicate) {
|
||||
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
|
||||
copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index()));
|
||||
}
|
||||
++pipeline_kv_producer_state;
|
||||
k_index += 1;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::collective
|
||||
276
examples/77_blackwell_fmha/device/fmha.hpp
Normal file
276
examples/77_blackwell_fmha/device/fmha.hpp
Normal file
@ -0,0 +1,276 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*!
|
||||
\file
|
||||
\brief An universal device layer for cutlass 3.x-style kernels.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// common
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include "cutlass/cluster_launch.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#endif // !defined(__CUDACC_RTC__)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::fmha::device {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class Kernel_>
|
||||
class FMHA {
|
||||
public:
|
||||
using Kernel = Kernel_;
|
||||
|
||||
static int const kThreadCount = Kernel::MaxThreadsPerBlock;
|
||||
|
||||
/// Argument structure: User API
|
||||
using Arguments = typename Kernel::Arguments;
|
||||
/// Argument structure: Kernel API
|
||||
using Params = typename Kernel::Params;
|
||||
|
||||
private:
|
||||
|
||||
/// Kernel API parameters object
|
||||
Params params_;
|
||||
|
||||
bool is_initialized(bool set = false) {
|
||||
static bool initialized = false;
|
||||
if (set) initialized = true;
|
||||
return initialized;
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Access the Params structure
|
||||
Params const& params() const {
|
||||
return params_;
|
||||
}
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status
|
||||
can_implement(Arguments const& args) {
|
||||
if (Kernel::can_implement(args)) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
else {
|
||||
return Status::kInvalid;
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
size_t workspace_bytes = 0;
|
||||
workspace_bytes += Kernel::get_workspace_size(args);
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Computes the grid shape
|
||||
static dim3
|
||||
get_grid_shape(Params const& params) {
|
||||
return Kernel::get_grid_shape(params);
|
||||
}
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
|
||||
CUTLASS_TRACE_HOST("FMHA::maximum_active_blocks()");
|
||||
int max_active_blocks = -1;
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
// first, account for dynamic smem capacity if needed
|
||||
cudaError_t result;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
result = cudaFuncSetAttribute(
|
||||
device_kernel<Kernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaFuncSetAttribute() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// query occupancy after setting smem size
|
||||
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks,
|
||||
device_kernel<Kernel>,
|
||||
Kernel::MaxThreadsPerBlock,
|
||||
smem_size);
|
||||
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status
|
||||
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("FMHA::initialize() - workspace "
|
||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
// Initialize the workspace
|
||||
Status status = Kernel::initialize_workspace(args, workspace, stream);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
// Initialize the Params structure
|
||||
params_ = Kernel::to_underlying_arguments(args, workspace);
|
||||
|
||||
if (is_initialized()) return Status::kSuccess;
|
||||
|
||||
// account for dynamic smem capacity if needed
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
cudaError_t result = cudaFuncSetAttribute(
|
||||
device_kernel<Kernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
is_initialized(true);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
|
||||
Status
|
||||
update(Arguments const& args, void* workspace = nullptr) {
|
||||
CUTLASS_TRACE_HOST("FMHA()::update() - workspace: " << workspace);
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
if (workspace_bytes > 0 && nullptr == workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
params_ = Kernel::to_underlying_arguments(args, workspace);
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Primary run() entry point API that is static allowing users to create and manage their own params.
|
||||
/// Supplied params struct must be construct by calling Kernel::to_underling_arguments()
|
||||
static Status
|
||||
run(Params& params, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("FMHA::run()");
|
||||
dim3 const block = Kernel::get_block_shape();
|
||||
dim3 const grid = get_grid_shape(params);
|
||||
|
||||
// configure smem size and carveout
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
Status launch_result;
|
||||
// Use extended launch API only for mainloops that use it
|
||||
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
|
||||
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
|
||||
cute::size<1>(typename Kernel::ClusterShape{}),
|
||||
cute::size<2>(typename Kernel::ClusterShape{}));
|
||||
void const* kernel = (void const*) device_kernel<Kernel>;
|
||||
void* kernel_params[] = {¶ms};
|
||||
launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
|
||||
}
|
||||
else {
|
||||
launch_result = Status::kSuccess;
|
||||
device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params);
|
||||
}
|
||||
|
||||
cudaError_t result = cudaGetLastError();
|
||||
if (cudaSuccess == result && Status::kSuccess == launch_result) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
else {
|
||||
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
|
||||
//
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
Status status = initialize(args, workspace, stream);
|
||||
if (Status::kSuccess == status) {
|
||||
status = run(params_, stream);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
return run(args, workspace, stream);
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
run(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
operator()(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::device
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
85
examples/77_blackwell_fmha/kernel/fmha_options.hpp
Normal file
85
examples/77_blackwell_fmha/kernel/fmha_options.hpp
Normal file
@ -0,0 +1,85 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
template<auto kTag, typename Default, typename... Options>
|
||||
struct find_option;
|
||||
|
||||
template<auto kTag, typename Default>
|
||||
struct find_option<kTag, Default> {
|
||||
using option_value = Default;
|
||||
};
|
||||
|
||||
template<auto kTag, typename Default, typename Option, typename... Options>
|
||||
struct find_option<kTag, Default, Option, Options...> :
|
||||
std::conditional_t<
|
||||
Option::tag == kTag,
|
||||
Option,
|
||||
find_option<kTag, Default, Options...>
|
||||
>
|
||||
{};
|
||||
|
||||
template<auto kTag, typename Default, typename... Options>
|
||||
using find_option_t = typename find_option<kTag, Default, Options...>::option_value;
|
||||
|
||||
enum class Tag {
|
||||
kIsPersistent,
|
||||
kNumMmaWarpGroups,
|
||||
kLoadsQSeparately,
|
||||
|
||||
kIsMainloopLocked,
|
||||
kIsEpilogueLocked,
|
||||
|
||||
kStagesQ,
|
||||
kStagesKV,
|
||||
|
||||
kEpilogueKind,
|
||||
|
||||
kBlocksPerSM,
|
||||
kClusterM,
|
||||
|
||||
kAccQK
|
||||
};
|
||||
|
||||
template<auto kTag, class Value>
|
||||
struct Option {
|
||||
static constexpr auto tag = kTag;
|
||||
using option_value = Value;
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
162
examples/77_blackwell_fmha/kernel/fmha_tile_scheduler.hpp
Normal file
162
examples/77_blackwell_fmha/kernel/fmha_tile_scheduler.hpp
Normal file
@ -0,0 +1,162 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct IndividualTileScheduler {
|
||||
|
||||
struct Params {
|
||||
dim3 grid;
|
||||
};
|
||||
|
||||
bool valid_ = true;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
IndividualTileScheduler(Params const&) {}
|
||||
|
||||
template<class ProblemSize, class ClusterShape, class TileShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, TileShape const& tile_shape) {
|
||||
using namespace cute;
|
||||
dim3 grid(round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,0>(problem_size), size<3,1>(problem_size));
|
||||
return Params{ grid };
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return params.grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return valid_;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
return make_coord(blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
IndividualTileScheduler& operator++() {
|
||||
valid_ = false;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct PersistentTileScheduler {
|
||||
|
||||
struct Params {
|
||||
int num_blocks;
|
||||
FastDivmod divmod_m_block;
|
||||
FastDivmod divmod_b;
|
||||
FastDivmod divmod_h;
|
||||
|
||||
KernelHardwareInfo hw_info;
|
||||
};
|
||||
|
||||
int block_idx = 0;
|
||||
Params params;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
PersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
|
||||
|
||||
template<class ProblemSize, class ClusterShape, class TileShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, TileShape const& tile_shape) {
|
||||
using namespace cute;
|
||||
// Get SM count if needed, otherwise use user supplied SM count
|
||||
int sm_count = hw_info.sm_count;
|
||||
if (sm_count <= 0) {
|
||||
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
||||
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
||||
hw_info.sm_count = sm_count;
|
||||
|
||||
int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape));
|
||||
int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size);
|
||||
|
||||
return Params {
|
||||
num_blocks,
|
||||
{ num_m_blocks}, { size<3,0>(problem_size) }, { size<3,1>(problem_size) },
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
|
||||
return grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return block_idx < params.num_blocks;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
int block_decode = block_idx;
|
||||
int m_block, bidb, bidh;
|
||||
params.divmod_m_block(block_decode, m_block, block_decode);
|
||||
params.divmod_b(block_decode, bidb, block_decode);
|
||||
params.divmod_h(block_decode, bidh, block_decode);
|
||||
return make_coord(m_block, _0{}, make_coord(bidb, bidh));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
PersistentTileScheduler& operator++() {
|
||||
block_idx += gridDim.x;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
@ -0,0 +1,519 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/layout.hpp"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cute/arch/tmem_allocator_sm100.hpp"
|
||||
|
||||
#include "kernel/fmha_options.hpp"
|
||||
#include "kernel/fmha_tile_scheduler.hpp"
|
||||
#include "collective/fmha_fusion.hpp"
|
||||
#include "collective/fmha_common.hpp"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::collective;
|
||||
|
||||
struct Sm100FmhaCtxKernelWarpspecializedSchedule {
|
||||
|
||||
enum class WarpRole {
|
||||
Softmax0,
|
||||
Softmax1,
|
||||
Correction,
|
||||
MMA,
|
||||
Load,
|
||||
Epilogue,
|
||||
Empty
|
||||
};
|
||||
|
||||
static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {
|
||||
int wg_idx = warp_idx / 4; // warp_idx
|
||||
if (wg_idx == 0) return WarpRole::Softmax0; // 0 - 3
|
||||
if (wg_idx == 1) return WarpRole::Softmax1; // 4 - 7
|
||||
if (wg_idx == 2) return WarpRole::Correction; // 8 - 11
|
||||
if (warp_idx == 12) return WarpRole::MMA; // 12
|
||||
if (warp_idx == 13) return WarpRole::Load; // 13
|
||||
if (warp_idx == 14) return WarpRole::Epilogue; // 14
|
||||
return WarpRole::Empty; // 15
|
||||
}
|
||||
|
||||
static const int NumWarpsSoftmax = 4;
|
||||
static const int NumWarpsCorrection = 4;
|
||||
static const int NumWarpsEpilogue = 1;
|
||||
static const int NumWarpsLoad = 1;
|
||||
|
||||
static const bool kDebugUsingPrintf = false;
|
||||
static const int NumRegsSoftmax = 192;
|
||||
static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0);
|
||||
static const int NumRegsOther = 32 + (kDebugUsingPrintf ? 16 : 0);
|
||||
static const int NumRegsEmpty = 24;
|
||||
|
||||
static const int NumWarps = 16;
|
||||
|
||||
};
|
||||
|
||||
template<
|
||||
class ProblemShapeIn,
|
||||
class CollectiveMainloop,
|
||||
class CollectiveEpilogue,
|
||||
class TileScheduler,
|
||||
class KernelSchedule = Sm100FmhaCtxKernelWarpspecializedSchedule
|
||||
>
|
||||
struct Sm100FmhaFwdKernelTmaWarpspecialized {
|
||||
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
using ProblemShape = ProblemShapeIn;
|
||||
|
||||
using WarpRole = typename KernelSchedule::WarpRole;
|
||||
|
||||
constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {
|
||||
return KernelSchedule::warp_idx_to_WarpRole(warp_idx);
|
||||
}
|
||||
|
||||
static const int NumWarpsSoftmax = KernelSchedule::NumWarpsSoftmax;
|
||||
static const int NumWarpsCorrection = KernelSchedule::NumWarpsCorrection;
|
||||
static const int NumWarpsEpilogue = KernelSchedule::NumWarpsEpilogue;
|
||||
static const int NumWarpsLoad = KernelSchedule::NumWarpsLoad;
|
||||
|
||||
static const int NumRegsSoftmax = KernelSchedule::NumRegsSoftmax;
|
||||
static const int NumRegsCorrection = KernelSchedule::NumRegsCorrection;
|
||||
static const int NumRegsOther = KernelSchedule::NumRegsOther;
|
||||
static const int NumRegsEmpty = 24;
|
||||
|
||||
static const int NumWarps = KernelSchedule::NumWarps;
|
||||
|
||||
using ClusterShape = typename CollectiveMainloop::ClusterShape;
|
||||
|
||||
using TmemAllocator = cute::TMEM::Allocator1Sm;
|
||||
|
||||
struct SharedStorage {
|
||||
typename CollectiveMainloop::TensorStorage mainloop;
|
||||
typename CollectiveEpilogue::TensorStorage epilogue;
|
||||
|
||||
struct PipelineStorage {
|
||||
alignas(16) typename CollectiveMainloop::PipelineQ::SharedStorage load_q;
|
||||
alignas(16) typename CollectiveMainloop::PipelineKV::SharedStorage load_kv;
|
||||
alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s0;
|
||||
alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s1;
|
||||
alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s0_corr;
|
||||
alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s1_corr;
|
||||
alignas(16) typename CollectiveMainloop::PipelineO::SharedStorage mma_corr;
|
||||
alignas(16) typename CollectiveMainloop::PipelineE::SharedStorage corr_epi;
|
||||
alignas(16) typename CollectiveMainloop::OrderBarrierSoftmax::SharedStorage order_s01;
|
||||
} pipelines;
|
||||
|
||||
uint32_t tmem_base_ptr;
|
||||
};
|
||||
|
||||
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
||||
|
||||
struct Arguments {
|
||||
ProblemShape problem_shape;
|
||||
typename CollectiveMainloop::Arguments mainloop;
|
||||
typename CollectiveEpilogue::Arguments epilogue;
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
};
|
||||
|
||||
struct Params {
|
||||
ProblemShape problem_shape;
|
||||
typename CollectiveMainloop::Params mainloop;
|
||||
typename CollectiveEpilogue::Params epilogue;
|
||||
typename TileScheduler::Params tile_scheduler;
|
||||
};
|
||||
|
||||
static const int MinBlocksPerMultiprocessor = 1;
|
||||
static const int MaxThreadsPerBlock = NumWarps * cutlass::NumThreadsPerWarp;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
|
||||
static size_t get_workspace_size(Arguments const& args) { return 0; }
|
||||
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
return CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return TileScheduler::get_grid_shape(params.tile_scheduler);
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() {
|
||||
dim3 block(MaxThreadsPerBlock, 1, 1);
|
||||
return block;
|
||||
}
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
return Params{
|
||||
args.problem_shape,
|
||||
CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
|
||||
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace),
|
||||
TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, TileShape{})
|
||||
};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE auto apply_batch(const Params ¶ms, ProblemShape const& problem_shape, int batch_idx) {
|
||||
return apply_variable_length(params.problem_shape, batch_idx);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
||||
|
||||
TileScheduler tile_scheduler{params.tile_scheduler};
|
||||
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
auto role = warp_idx_to_WarpRole(warp_idx);
|
||||
uint32_t lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (role == WarpRole::Load && lane_predicate) {
|
||||
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
||||
}
|
||||
|
||||
if (role == WarpRole::Epilogue && lane_predicate) {
|
||||
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
|
||||
}
|
||||
|
||||
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem);
|
||||
|
||||
typename CollectiveMainloop::PipelineQ::Params pipeline_load_q_params;
|
||||
if (role == WarpRole::Load) {
|
||||
pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::MMA) {
|
||||
pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_load_q_params.is_leader = lane_predicate && (role == WarpRole::Load);
|
||||
pipeline_load_q_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadQ;
|
||||
typename CollectiveMainloop::PipelineQ pipeline_load_q(
|
||||
shared_storage.pipelines.load_q,
|
||||
pipeline_load_q_params,
|
||||
ClusterShape{}, cute::true_type{}, /*mask calc*/cute::false_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineKV::Params pipeline_load_kv_params;
|
||||
if (role == WarpRole::Load) {
|
||||
pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::MMA) {
|
||||
pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_load_kv_params.is_leader = lane_predicate && (role == WarpRole::Load);
|
||||
pipeline_load_kv_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadKV;
|
||||
typename CollectiveMainloop::PipelineKV pipeline_load_kv(
|
||||
shared_storage.pipelines.load_kv,
|
||||
pipeline_load_kv_params,
|
||||
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineS::Params pipeline_mma_s0_params;
|
||||
if (role == WarpRole::MMA) {
|
||||
pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Softmax0) {
|
||||
pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_mma_s0_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineS pipeline_mma_s0(
|
||||
shared_storage.pipelines.mma_s0,
|
||||
pipeline_mma_s0_params,
|
||||
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineS::Params pipeline_mma_s1_params;
|
||||
if (role == WarpRole::MMA) {
|
||||
pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Softmax1) {
|
||||
pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_mma_s1_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineS pipeline_mma_s1(
|
||||
shared_storage.pipelines.mma_s1,
|
||||
pipeline_mma_s1_params,
|
||||
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineC::Params pipeline_s0_corr_params;
|
||||
if (role == WarpRole::Softmax0) {
|
||||
pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Correction) {
|
||||
pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_s0_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
|
||||
pipeline_s0_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineC pipeline_s0_corr(
|
||||
shared_storage.pipelines.s0_corr,
|
||||
pipeline_s0_corr_params,
|
||||
/*barrier init*/ cute::true_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineC::Params pipeline_s1_corr_params;
|
||||
if (role == WarpRole::Softmax1) {
|
||||
pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Correction) {
|
||||
pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_s1_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
|
||||
pipeline_s1_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineC pipeline_s1_corr(
|
||||
shared_storage.pipelines.s1_corr,
|
||||
pipeline_s1_corr_params,
|
||||
/*barrier init*/ cute::true_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineO::Params pipeline_mma_corr_params;
|
||||
if (role == WarpRole::MMA) {
|
||||
pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Correction) {
|
||||
pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_mma_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineO pipeline_mma_corr(
|
||||
shared_storage.pipelines.mma_corr,
|
||||
pipeline_mma_corr_params,
|
||||
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineE::Params pipeline_corr_epi_params;
|
||||
if (role == WarpRole::Correction) {
|
||||
pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Epilogue) {
|
||||
pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
|
||||
pipeline_corr_epi_params.consumer_arv_count = NumWarpsEpilogue * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineE pipeline_corr_epi(
|
||||
shared_storage.pipelines.corr_epi,
|
||||
pipeline_corr_epi_params,
|
||||
/*barrier init*/ cute::true_type{});
|
||||
|
||||
typename CollectiveMainloop::OrderBarrierSoftmax::Params params_order_s01;
|
||||
params_order_s01.group_id = role == WarpRole::Softmax1 ? 1 : 0;
|
||||
params_order_s01.group_size = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::OrderBarrierSoftmax order_s01(
|
||||
shared_storage.pipelines.order_s01, params_order_s01);
|
||||
|
||||
TmemAllocator tmem_allocator;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
pipeline_load_q.init_masks(ClusterShape{});
|
||||
pipeline_load_kv.init_masks(ClusterShape{});
|
||||
pipeline_mma_s0.init_masks(ClusterShape{});
|
||||
pipeline_mma_s1.init_masks(ClusterShape{});
|
||||
pipeline_mma_corr.init_masks(ClusterShape{});
|
||||
|
||||
typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_consumer_state;
|
||||
typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineQ>();
|
||||
|
||||
typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_consumer_state;
|
||||
typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineKV>();
|
||||
|
||||
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_consumer_state;
|
||||
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineS>();
|
||||
|
||||
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_consumer_state;
|
||||
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineS>();
|
||||
|
||||
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_consumer_state;
|
||||
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineC>();
|
||||
|
||||
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_consumer_state;
|
||||
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineC>();
|
||||
|
||||
typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_consumer_state;
|
||||
typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineE>();
|
||||
|
||||
typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_consumer_state;
|
||||
typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineO>();
|
||||
|
||||
CollectiveMainloop mainloop;
|
||||
CollectiveEpilogue epilogue;
|
||||
|
||||
if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) {
|
||||
warpgroup_reg_set<NumRegsSoftmax>();
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
|
||||
auto logical_problem_shape = apply_batch(params,
|
||||
params.problem_shape, get<2,1>(blk_coord));
|
||||
|
||||
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool is_softmax_0 = role == WarpRole::Softmax0;
|
||||
|
||||
mainloop.softmax(
|
||||
is_softmax_0 ? 0 : 1, blk_coord,
|
||||
params.mainloop, logical_problem_shape,
|
||||
is_softmax_0 ? pipeline_mma_s0 : pipeline_mma_s1,
|
||||
is_softmax_0 ? pipeline_mma_s0_consumer_state : pipeline_mma_s1_consumer_state,
|
||||
is_softmax_0 ? pipeline_s0_corr : pipeline_s1_corr,
|
||||
is_softmax_0 ? pipeline_s0_corr_producer_state : pipeline_s1_corr_producer_state,
|
||||
order_s01
|
||||
);
|
||||
|
||||
}
|
||||
}
|
||||
else if (role == WarpRole::Correction) {
|
||||
cutlass::arch::warpgroup_reg_dealloc<NumRegsCorrection>();
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
|
||||
auto logical_problem_shape = apply_batch(params,
|
||||
params.problem_shape, get<2,1>(blk_coord));
|
||||
|
||||
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
mainloop.correction(
|
||||
blk_coord,
|
||||
params.mainloop, logical_problem_shape,
|
||||
shared_storage.epilogue,
|
||||
pipeline_s0_corr, pipeline_s0_corr_consumer_state,
|
||||
pipeline_s1_corr, pipeline_s1_corr_consumer_state,
|
||||
pipeline_mma_corr, pipeline_mma_corr_consumer_state,
|
||||
pipeline_corr_epi, pipeline_corr_epi_producer_state
|
||||
);
|
||||
|
||||
|
||||
}
|
||||
|
||||
if constexpr (NumWarpsEpilogue == 0) {
|
||||
static_assert(NumWarpsCorrection == 1);
|
||||
|
||||
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
|
||||
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
|
||||
}
|
||||
|
||||
}
|
||||
else if (role == WarpRole::MMA) {
|
||||
warpgroup_reg_set<NumRegsOther>();
|
||||
|
||||
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
|
||||
__syncwarp();
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
|
||||
auto logical_problem_shape = apply_batch(params,
|
||||
params.problem_shape, get<2,1>(blk_coord));
|
||||
|
||||
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
mainloop.mma(
|
||||
blk_coord,
|
||||
params.mainloop, logical_problem_shape,
|
||||
shared_storage.mainloop,
|
||||
pipeline_load_q, pipeline_load_q_consumer_state,
|
||||
pipeline_load_kv, pipeline_load_kv_consumer_state,
|
||||
pipeline_mma_s0, pipeline_mma_s0_producer_state,
|
||||
pipeline_mma_s1, pipeline_mma_s1_producer_state,
|
||||
pipeline_mma_corr, pipeline_mma_corr_producer_state
|
||||
);
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
else if (role == WarpRole::Load) {
|
||||
warpgroup_reg_set<NumRegsOther>();
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
|
||||
auto logical_problem_shape = apply_batch(params,
|
||||
params.problem_shape, get<2,1>(blk_coord));
|
||||
|
||||
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
mainloop.load(
|
||||
blk_coord, logical_problem_shape,
|
||||
params.mainloop, params.problem_shape,
|
||||
shared_storage.mainloop,
|
||||
pipeline_load_q, pipeline_load_q_producer_state,
|
||||
pipeline_load_kv, pipeline_load_kv_producer_state
|
||||
);
|
||||
|
||||
}
|
||||
}
|
||||
else if (role == WarpRole::Epilogue) {
|
||||
warpgroup_reg_set<NumRegsOther>();
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
|
||||
auto logical_problem_shape = apply_batch(params,
|
||||
params.problem_shape, get<2,1>(blk_coord));
|
||||
|
||||
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
epilogue.store(
|
||||
blk_coord, logical_problem_shape,
|
||||
params.epilogue, params.problem_shape,
|
||||
shared_storage.epilogue,
|
||||
pipeline_corr_epi, pipeline_corr_epi_consumer_state
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
static_assert(NumWarpsEpilogue <= 1);
|
||||
if constexpr (NumWarpsEpilogue == 1) {
|
||||
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
|
||||
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
|
||||
}
|
||||
|
||||
}
|
||||
else if (role == WarpRole::Empty) {
|
||||
warpgroup_reg_set<NumRegsEmpty>();
|
||||
|
||||
/* no-op, donate regs and exit */
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
@ -0,0 +1,576 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/layout.hpp"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cute/arch/tmem_allocator_sm100.hpp"
|
||||
|
||||
#include "kernel/fmha_options.hpp"
|
||||
#include "kernel/fmha_tile_scheduler.hpp"
|
||||
#include "collective/fmha_fusion.hpp"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::collective;
|
||||
|
||||
struct Sm100FmhaGenKernelWarpspecializedSchedule {
|
||||
|
||||
enum class WarpRole {
|
||||
Softmax0,
|
||||
Softmax1,
|
||||
Correction,
|
||||
MMA,
|
||||
Load,
|
||||
Epilogue,
|
||||
Empty
|
||||
};
|
||||
|
||||
static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {
|
||||
if (warp_idx == 0) return WarpRole::Softmax0; // 0 - 3
|
||||
if (warp_idx == 1) return WarpRole::MMA; // 12
|
||||
if (warp_idx == 2 || warp_idx == 3) return WarpRole::Load; // 13
|
||||
if (warp_idx == 4) return WarpRole::Softmax1; // 4 - 7
|
||||
if (warp_idx == 8) return WarpRole::Correction; // 8 - 11
|
||||
return WarpRole::Empty; // 15
|
||||
}
|
||||
|
||||
static const int NumWarpsSoftmax = 1;
|
||||
static const int NumWarpsCorrection = 1;
|
||||
static const int NumWarpsEpilogue = 0;
|
||||
static const int NumWarpsLoad = 2;
|
||||
|
||||
static const int NumRegsSoftmax = 192;
|
||||
static const int NumRegsCorrection = 104;
|
||||
static const int NumRegsOther = 248;
|
||||
static const int NumRegsEmpty = 24;
|
||||
|
||||
static const int NumWarps = 12;
|
||||
|
||||
};
|
||||
|
||||
template<
|
||||
class ProblemShapeIn,
|
||||
class CollectiveMainloop,
|
||||
class CollectiveEpilogue,
|
||||
class TileScheduler,
|
||||
class KernelSchedule = Sm100FmhaGenKernelWarpspecializedSchedule
|
||||
>
|
||||
struct Sm100FmhaGenKernelWarpspecialized {
|
||||
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
using ProblemShape = decltype(replace<0>(ProblemShapeIn{}, 0));
|
||||
|
||||
using WarpRole = typename KernelSchedule::WarpRole;
|
||||
|
||||
constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) {
|
||||
return KernelSchedule::warp_idx_to_WarpRole(warp_idx);
|
||||
}
|
||||
|
||||
static const int NumWarpsSoftmax = KernelSchedule::NumWarpsSoftmax;
|
||||
static const int NumWarpsCorrection = KernelSchedule::NumWarpsCorrection;
|
||||
static const int NumWarpsEpilogue = KernelSchedule::NumWarpsEpilogue;
|
||||
static const int NumWarpsLoad = KernelSchedule::NumWarpsLoad;
|
||||
|
||||
static const int NumRegsSoftmax = KernelSchedule::NumRegsSoftmax;
|
||||
static const int NumRegsCorrection = KernelSchedule::NumRegsCorrection;
|
||||
static const int NumRegsOther = KernelSchedule::NumRegsOther;
|
||||
static const int NumRegsEmpty = 24;
|
||||
|
||||
static const int NumWarps = KernelSchedule::NumWarps;
|
||||
|
||||
using ClusterShape = typename CollectiveMainloop::ClusterShape;
|
||||
|
||||
using TmemAllocator = cute::TMEM::Allocator1Sm;
|
||||
|
||||
struct SharedStorage {
|
||||
typename CollectiveMainloop::TensorStorage mainloop;
|
||||
typename CollectiveEpilogue::TensorStorage epilogue;
|
||||
|
||||
struct PipelineStorage {
|
||||
alignas(16) typename CollectiveMainloop::PipelineQ::SharedStorage load_q;
|
||||
alignas(16) typename CollectiveMainloop::PipelineKV::SharedStorage load_kv;
|
||||
alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s0;
|
||||
alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s1;
|
||||
alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s0_corr;
|
||||
alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s1_corr;
|
||||
alignas(16) typename CollectiveMainloop::PipelineO::SharedStorage mma_corr;
|
||||
alignas(16) typename CollectiveMainloop::PipelineE::SharedStorage corr_epi;
|
||||
alignas(16) typename CollectiveMainloop::OrderBarrierSoftmax::SharedStorage order_s01;
|
||||
} pipelines;
|
||||
|
||||
uint32_t tmem_base_ptr;
|
||||
};
|
||||
|
||||
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
||||
|
||||
using StrideQOrig = typename CollectiveMainloop::StrideQOrig;
|
||||
using StrideOOrig = typename CollectiveMainloop::StrideOOrig;
|
||||
using StrideQ = typename CollectiveMainloop::StrideQ;
|
||||
using StrideO = typename CollectiveMainloop::StrideO;
|
||||
using StrideCacheK = typename CollectiveMainloop::StrideCacheK;
|
||||
using StrideCacheV = typename CollectiveMainloop::StrideCacheV;
|
||||
using StrideNewK = typename CollectiveMainloop::StrideNewK;
|
||||
using StrideNewV = typename CollectiveMainloop::StrideNewV;
|
||||
using Element = typename CollectiveMainloop::Element;
|
||||
using ElementAcc = typename CollectiveMainloop::ElementAcc;
|
||||
using ElementOut = typename CollectiveMainloop::ElementOut;
|
||||
|
||||
struct Arguments {
|
||||
// _1, max_seqlen_k, head_dim, ((h_g, h_kv), b)
|
||||
ProblemShapeIn problem_shape;
|
||||
const int* seqlen_kv;
|
||||
const int* cache_batch_idx;
|
||||
|
||||
const Element* ptr_q; // 1 x D x (H x B)
|
||||
StrideQOrig dQ;
|
||||
const Element* ptr_new_k; // 1 x D x (H x B)
|
||||
StrideNewK dNewK;
|
||||
const Element* ptr_new_v; // 1 x D x (H x B)
|
||||
StrideNewV dNewV;
|
||||
|
||||
Element* ptr_cache_k; // seqlen_max x D x (H x B)
|
||||
StrideCacheK dCacheK;
|
||||
Element* ptr_cache_v; // seqlen_max x D x (H x B)
|
||||
StrideCacheV dCacheV;
|
||||
ElementOut* ptr_o; // 1 x D x (H x B)
|
||||
StrideOOrig dO;
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
|
||||
ElementAcc scale_softmax = 0.0f;
|
||||
};
|
||||
|
||||
struct Params {
|
||||
ProblemShape problem_shape;
|
||||
const int* seqlen_kv;
|
||||
typename CollectiveMainloop::Params mainloop;
|
||||
typename CollectiveEpilogue::Params epilogue;
|
||||
typename TileScheduler::Params tile_scheduler;
|
||||
};
|
||||
|
||||
static const int MinBlocksPerMultiprocessor = 1;
|
||||
static const int MaxThreadsPerBlock = NumWarps * cutlass::NumThreadsPerWarp;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
|
||||
static size_t get_workspace_size(Arguments const& args) { return 0; }
|
||||
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
return true;
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return TileScheduler::get_grid_shape(params.tile_scheduler);
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() {
|
||||
dim3 block(MaxThreadsPerBlock, 1, 1);
|
||||
return block;
|
||||
}
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
ProblemShape problem_shape = replace<0>(args.problem_shape, static_cast<int>(get<0>(args.problem_shape)));
|
||||
CUTE_STATIC_ASSERT_V(get<0>(args.problem_shape) == _1{});
|
||||
StrideQ dQ = replace<0>(args.dQ, 0);
|
||||
StrideO dO = replace<0>(args.dO, 0);
|
||||
get<0>(problem_shape) = get<3,0,0>(args.problem_shape);
|
||||
get<3,0,0>(problem_shape) = 1;
|
||||
get<0>(dQ) = get<2,0,0>(dQ);
|
||||
get<0>(dO) = get<2,0,0>(dO);
|
||||
|
||||
typename CollectiveMainloop::Arguments mainloop_args {
|
||||
{
|
||||
args.cache_batch_idx,
|
||||
args.ptr_q, dQ,
|
||||
args.ptr_new_k, args.dNewK,
|
||||
args.ptr_new_v, args.dNewV,
|
||||
args.ptr_cache_k, args.dCacheK,
|
||||
args.ptr_cache_v, args.dCacheV,
|
||||
},
|
||||
args.scale_softmax
|
||||
};
|
||||
|
||||
typename CollectiveEpilogue::Arguments epilogue_args {
|
||||
args.ptr_o, dO,
|
||||
};
|
||||
|
||||
return Params{
|
||||
problem_shape,
|
||||
args.seqlen_kv,
|
||||
CollectiveMainloop::to_underlying_arguments(problem_shape, mainloop_args, workspace),
|
||||
CollectiveEpilogue::to_underlying_arguments(problem_shape, epilogue_args, workspace),
|
||||
TileScheduler::to_underlying_arguments(problem_shape, args.hw_info, ClusterShape{}, TileShape{})
|
||||
};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE auto apply_batch(const Params ¶ms, ProblemShape const& problem_shape, int batch_idx) {
|
||||
ProblemShape result = problem_shape;
|
||||
get<1>(result) = params.seqlen_kv[batch_idx];
|
||||
if (params.mainloop.load.ptr_new_k != nullptr) {
|
||||
get<1>(result) += 1;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
||||
|
||||
TileScheduler tile_scheduler{params.tile_scheduler};
|
||||
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
auto role = warp_idx_to_WarpRole(warp_idx);
|
||||
uint32_t lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (role == WarpRole::Load && lane_predicate) {
|
||||
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
||||
}
|
||||
|
||||
if (role == WarpRole::Epilogue && lane_predicate) {
|
||||
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
|
||||
}
|
||||
|
||||
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem);
|
||||
|
||||
typename CollectiveMainloop::PipelineQ::Params pipeline_load_q_params;
|
||||
if (role == WarpRole::Load) {
|
||||
pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::MMA) {
|
||||
pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_load_q_params.producer_arv_count = NumWarpsLoad * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineQ pipeline_load_q(
|
||||
shared_storage.pipelines.load_q,
|
||||
pipeline_load_q_params,
|
||||
ClusterShape{}, cute::true_type{}, /*mask calc*/cute::false_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineKV::Params pipeline_load_kv_params;
|
||||
if (role == WarpRole::Load) {
|
||||
pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::MMA) {
|
||||
pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_load_kv_params.producer_arv_count = NumWarpsLoad * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineKV pipeline_load_kv(
|
||||
shared_storage.pipelines.load_kv,
|
||||
pipeline_load_kv_params,
|
||||
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineS::Params pipeline_mma_s0_params;
|
||||
if (role == WarpRole::MMA) {
|
||||
pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Softmax0) {
|
||||
pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_mma_s0_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineS pipeline_mma_s0(
|
||||
shared_storage.pipelines.mma_s0,
|
||||
pipeline_mma_s0_params,
|
||||
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineS::Params pipeline_mma_s1_params;
|
||||
if (role == WarpRole::MMA) {
|
||||
pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Softmax1) {
|
||||
pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_mma_s1_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineS pipeline_mma_s1(
|
||||
shared_storage.pipelines.mma_s1,
|
||||
pipeline_mma_s1_params,
|
||||
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineC::Params pipeline_s0_corr_params;
|
||||
if (role == WarpRole::Softmax0) {
|
||||
pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Correction) {
|
||||
pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_s0_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
|
||||
pipeline_s0_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineC pipeline_s0_corr(
|
||||
shared_storage.pipelines.s0_corr,
|
||||
pipeline_s0_corr_params,
|
||||
/*barrier init*/ cute::true_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineC::Params pipeline_s1_corr_params;
|
||||
if (role == WarpRole::Softmax1) {
|
||||
pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Correction) {
|
||||
pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_s1_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
|
||||
pipeline_s1_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineC pipeline_s1_corr(
|
||||
shared_storage.pipelines.s1_corr,
|
||||
pipeline_s1_corr_params,
|
||||
/*barrier init*/ cute::true_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineO::Params pipeline_mma_corr_params;
|
||||
if (role == WarpRole::MMA) {
|
||||
pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Correction) {
|
||||
pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_mma_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineO pipeline_mma_corr(
|
||||
shared_storage.pipelines.mma_corr,
|
||||
pipeline_mma_corr_params,
|
||||
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
|
||||
|
||||
typename CollectiveMainloop::PipelineE::Params pipeline_corr_epi_params;
|
||||
if (role == WarpRole::Correction) {
|
||||
pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Producer;
|
||||
}
|
||||
if (role == WarpRole::Epilogue) {
|
||||
pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer;
|
||||
}
|
||||
pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
|
||||
pipeline_corr_epi_params.consumer_arv_count = NumWarpsEpilogue * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::PipelineE pipeline_corr_epi(
|
||||
shared_storage.pipelines.corr_epi,
|
||||
pipeline_corr_epi_params,
|
||||
/*barrier init*/ cute::true_type{});
|
||||
|
||||
typename CollectiveMainloop::OrderBarrierSoftmax::Params params_order_s01;
|
||||
params_order_s01.group_id = role == WarpRole::Softmax1 ? 1 : 0;
|
||||
params_order_s01.group_size = NumWarpsSoftmax * cutlass::NumThreadsPerWarp;
|
||||
typename CollectiveMainloop::OrderBarrierSoftmax order_s01(
|
||||
shared_storage.pipelines.order_s01, params_order_s01);
|
||||
|
||||
TmemAllocator tmem_allocator;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
pipeline_load_q.init_masks(ClusterShape{});
|
||||
pipeline_load_kv.init_masks(ClusterShape{});
|
||||
pipeline_mma_s0.init_masks(ClusterShape{});
|
||||
pipeline_mma_s1.init_masks(ClusterShape{});
|
||||
pipeline_mma_corr.init_masks(ClusterShape{});
|
||||
|
||||
typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_consumer_state;
|
||||
typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineQ>();
|
||||
|
||||
typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_consumer_state;
|
||||
typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineKV>();
|
||||
|
||||
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_consumer_state;
|
||||
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineS>();
|
||||
|
||||
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_consumer_state;
|
||||
typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineS>();
|
||||
|
||||
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_consumer_state;
|
||||
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineC>();
|
||||
|
||||
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_consumer_state;
|
||||
typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineC>();
|
||||
|
||||
typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_consumer_state;
|
||||
typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineE>();
|
||||
|
||||
typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_consumer_state;
|
||||
typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineO>();
|
||||
|
||||
CollectiveMainloop mainloop;
|
||||
CollectiveEpilogue epilogue(params.epilogue);
|
||||
|
||||
if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) {
|
||||
warpgroup_reg_set<NumRegsSoftmax>();
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
|
||||
auto logical_problem_shape = apply_batch(params,
|
||||
params.problem_shape, get<2,1>(blk_coord));
|
||||
|
||||
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool is_softmax_0 = role == WarpRole::Softmax0;
|
||||
|
||||
mainloop.softmax(
|
||||
is_softmax_0 ? 0 : 1, blk_coord,
|
||||
params.mainloop, logical_problem_shape,
|
||||
is_softmax_0 ? pipeline_mma_s0 : pipeline_mma_s1,
|
||||
is_softmax_0 ? pipeline_mma_s0_consumer_state : pipeline_mma_s1_consumer_state,
|
||||
is_softmax_0 ? pipeline_s0_corr : pipeline_s1_corr,
|
||||
is_softmax_0 ? pipeline_s0_corr_producer_state : pipeline_s1_corr_producer_state,
|
||||
order_s01
|
||||
);
|
||||
|
||||
}
|
||||
}
|
||||
else if (role == WarpRole::Correction) {
|
||||
cutlass::arch::warpgroup_reg_dealloc<NumRegsCorrection>();
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
|
||||
auto logical_problem_shape = apply_batch(params,
|
||||
params.problem_shape, get<2,1>(blk_coord));
|
||||
|
||||
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
mainloop.correction(
|
||||
blk_coord,
|
||||
params.mainloop, logical_problem_shape,
|
||||
shared_storage.epilogue,
|
||||
pipeline_s0_corr, pipeline_s0_corr_consumer_state,
|
||||
pipeline_s1_corr, pipeline_s1_corr_consumer_state,
|
||||
pipeline_mma_corr, pipeline_mma_corr_consumer_state,
|
||||
pipeline_corr_epi, pipeline_corr_epi_producer_state,
|
||||
epilogue
|
||||
);
|
||||
|
||||
|
||||
}
|
||||
|
||||
if constexpr (NumWarpsEpilogue == 0) {
|
||||
static_assert(NumWarpsCorrection == 1);
|
||||
|
||||
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
|
||||
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
|
||||
}
|
||||
|
||||
}
|
||||
else if (role == WarpRole::MMA) {
|
||||
warpgroup_reg_set<NumRegsOther>();
|
||||
|
||||
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
|
||||
__syncwarp();
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
|
||||
auto logical_problem_shape = apply_batch(params,
|
||||
params.problem_shape, get<2,1>(blk_coord));
|
||||
|
||||
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
mainloop.mma(
|
||||
blk_coord,
|
||||
params.mainloop, logical_problem_shape,
|
||||
shared_storage.mainloop,
|
||||
pipeline_load_q, pipeline_load_q_consumer_state,
|
||||
pipeline_load_kv, pipeline_load_kv_consumer_state,
|
||||
pipeline_mma_s0, pipeline_mma_s0_producer_state,
|
||||
pipeline_mma_s1, pipeline_mma_s1_producer_state,
|
||||
pipeline_mma_corr, pipeline_mma_corr_producer_state
|
||||
);
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
else if (role == WarpRole::Load) {
|
||||
warpgroup_reg_set<NumRegsOther>();
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
|
||||
auto logical_problem_shape = apply_batch(params,
|
||||
params.problem_shape, get<2,1>(blk_coord));
|
||||
|
||||
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
mainloop.load(
|
||||
blk_coord, logical_problem_shape,
|
||||
params.mainloop, params.problem_shape,
|
||||
shared_storage.mainloop,
|
||||
pipeline_load_q, pipeline_load_q_producer_state,
|
||||
pipeline_load_kv, pipeline_load_kv_producer_state
|
||||
);
|
||||
|
||||
}
|
||||
}
|
||||
else if (role == WarpRole::Epilogue) {
|
||||
warpgroup_reg_set<NumRegsOther>();
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
|
||||
auto logical_problem_shape = apply_batch(params,
|
||||
params.problem_shape, get<2,1>(blk_coord));
|
||||
|
||||
if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
epilogue.store(
|
||||
blk_coord, logical_problem_shape,
|
||||
params.epilogue, params.problem_shape,
|
||||
shared_storage.epilogue,
|
||||
pipeline_corr_epi, pipeline_corr_epi_consumer_state
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
static_assert(NumWarpsEpilogue <= 1);
|
||||
if constexpr (NumWarpsEpilogue == 1) {
|
||||
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
|
||||
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
|
||||
}
|
||||
|
||||
}
|
||||
else if (role == WarpRole::Empty) {
|
||||
warpgroup_reg_set<NumRegsEmpty>();
|
||||
|
||||
/* no-op, donate regs and exit */
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
187
examples/77_blackwell_fmha/reference/fmha_fwd_gen_reference.hpp
Normal file
187
examples/77_blackwell_fmha/reference/fmha_fwd_gen_reference.hpp
Normal file
@ -0,0 +1,187 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ElementAcc,
|
||||
class ProblemShape,
|
||||
class TensorQ,
|
||||
class TensorNewK,
|
||||
class TensorNewV,
|
||||
class TensorCacheK,
|
||||
class TensorCacheV,
|
||||
class TensorO
|
||||
>
|
||||
void __global__ fmha_fwd_gen_reference_kernel(
|
||||
ProblemShape problem_shape,
|
||||
const int* seqlen_kv, const int* cache_batch_idx,
|
||||
TensorQ mQ, TensorNewK mNewK, TensorNewV mNewV,
|
||||
TensorCacheK mCacheK, TensorCacheV mCacheV, TensorO mO) {
|
||||
|
||||
using namespace cute;
|
||||
extern __shared__ char mS_mem[];
|
||||
ElementAcc* mS = reinterpret_cast<ElementAcc*>(mS_mem);
|
||||
|
||||
float scale = 1.0f / std::sqrt(float(get<2>(problem_shape)));
|
||||
|
||||
if (mNewK.data() != nullptr) {
|
||||
// 1. copy in new_k to cache
|
||||
for (int idx_h = blockIdx.x; idx_h < size<3,0,1>(problem_shape); idx_h += gridDim.x) {
|
||||
for (int idx_b = blockIdx.z; idx_b < size<3,1>(problem_shape); idx_b += gridDim.z) {
|
||||
int idx_b_kv = cache_batch_idx != nullptr ? cache_batch_idx[idx_b] : idx_b;
|
||||
for (int idx_d = threadIdx.x; idx_d < size<2>(problem_shape); idx_d += blockDim.x) {
|
||||
mCacheK(seqlen_kv[idx_b], idx_d, make_coord(make_coord(_0{}, idx_h), idx_b_kv)) =
|
||||
mNewK(_0{}, idx_d, make_coord(make_coord(_0{}, idx_h), idx_b));
|
||||
mCacheV(seqlen_kv[idx_b], idx_d, make_coord(make_coord(_0{}, idx_h), idx_b_kv)) =
|
||||
mNewV(_0{}, idx_d, make_coord(make_coord(_0{}, idx_h), idx_b));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. compute attention
|
||||
for (int idx_h_kv = blockIdx.x; idx_h_kv < size<3,0,1>(problem_shape); idx_h_kv += gridDim.x) {
|
||||
for (int idx_h_qo = blockIdx.y; idx_h_qo < size<3,0,0>(problem_shape); idx_h_qo += gridDim.y) {
|
||||
int idx_h = idx_h_qo + size<3,0,0>(problem_shape) * idx_h_kv;
|
||||
for (int idx_b = blockIdx.z; idx_b < size<3,1>(problem_shape); idx_b += gridDim.z) {
|
||||
int idx_b_kv = cache_batch_idx != nullptr ? cache_batch_idx[idx_b] : idx_b;
|
||||
const int kDim = 128;
|
||||
ElementAcc reg_o[kDim] = {0};
|
||||
ElementAcc row_max = -INFINITY;
|
||||
ElementAcc row_sum = 0;
|
||||
auto iteration = [&](auto const& tK, auto const& tV) {
|
||||
ElementAcc reg_s = 0;
|
||||
for (int idx_d = 0; idx_d < kDim; idx_d++) {
|
||||
ElementAcc eQ = mQ(_0{}, idx_d, make_coord(idx_h, idx_b));
|
||||
ElementAcc eK = tK(idx_d);
|
||||
reg_s += eQ * eK;
|
||||
}
|
||||
|
||||
ElementAcc old_row_max = row_max;
|
||||
row_max = std::max(row_max, reg_s);
|
||||
|
||||
ElementAcc adjustment = std::exp(scale * (old_row_max - row_max));
|
||||
row_sum *= adjustment;
|
||||
for (int idx_d = 0; idx_d < kDim; idx_d++) {
|
||||
reg_o[idx_d] *= adjustment;
|
||||
}
|
||||
|
||||
ElementAcc reg_p = std::exp(scale * (reg_s - row_max));
|
||||
row_sum += reg_p;
|
||||
|
||||
for (int idx_d = 0; idx_d < kDim; idx_d++) {
|
||||
ElementAcc eV = tV(idx_d);
|
||||
reg_o[idx_d] += reg_p * eV;
|
||||
}
|
||||
};
|
||||
|
||||
for (int idx_s = threadIdx.x; idx_s < seqlen_kv[idx_b]; idx_s += blockDim.x) {
|
||||
iteration(mCacheK(idx_s, _, make_coord(idx_h, idx_b_kv)), mCacheV(idx_s, _, make_coord(idx_h, idx_b_kv)));
|
||||
}
|
||||
|
||||
if (mNewK.data() != nullptr && threadIdx.x == 0) {
|
||||
iteration(mNewK(_0{}, _, make_coord(idx_h, idx_b)), mNewV(_0{}, _, make_coord(idx_h, idx_b)));
|
||||
}
|
||||
|
||||
mS[threadIdx.x] = row_max;
|
||||
__syncthreads();
|
||||
float old_row_max = row_max;
|
||||
for (int i = 0; i < blockDim.x; i++) {
|
||||
row_max = std::max(row_max, mS[i]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
ElementAcc adjustment = std::exp(scale * (old_row_max - row_max));
|
||||
row_sum *= adjustment;
|
||||
for (int idx_d = 0; idx_d < kDim; idx_d++) {
|
||||
reg_o[idx_d] *= adjustment;
|
||||
}
|
||||
mS[threadIdx.x] = row_sum;
|
||||
__syncthreads();
|
||||
|
||||
row_sum = 0;
|
||||
for (int i = 0; i < blockDim.x; i++) {
|
||||
row_sum += mS[i];
|
||||
}
|
||||
__syncthreads();
|
||||
for (int idx_d = 0; idx_d < kDim; idx_d++) {
|
||||
mS[idx_d] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int idx_d = 0; idx_d < kDim; idx_d++) {
|
||||
reg_o[idx_d] /= row_sum;
|
||||
atomicAdd(&mS[idx_d], reg_o[idx_d]);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
for (int idx_d = threadIdx.x; idx_d < kDim; idx_d += blockDim.x) {
|
||||
|
||||
// printf("O[%d,%d,%d] = %f\n", idx_d, idx_h, idx_b, mS[idx_d]);
|
||||
mO(_0{}, idx_d, make_coord(idx_h, idx_b)) = static_cast<typename TensorO::value_type>(mS[idx_d]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<
|
||||
class ElementAcc,
|
||||
class ProblemShape,
|
||||
class TensorQ,
|
||||
class TensorNewK,
|
||||
class TensorNewV,
|
||||
class TensorCacheK,
|
||||
class TensorCacheV,
|
||||
class TensorO
|
||||
>
|
||||
void fmha_fwd_gen_reference(
|
||||
ProblemShape problem_shape,
|
||||
const int* seqlen_kv, const int* cache_batch_idx,
|
||||
TensorQ mQ, TensorNewK mNewK, TensorNewV mNewV,
|
||||
TensorCacheK mCacheK, TensorCacheV mCacheV, TensorO mO) {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
dim3 grid(get<3,0,1>(problem_shape), get<3,0,0>(problem_shape), get<3,1>(problem_shape));
|
||||
dim3 block(128);
|
||||
int shared_mem = int(sizeof(ElementAcc)) * std::max<int>(128, block.x);
|
||||
assert(get<2>(problem_shape) == 128);
|
||||
fmha_fwd_gen_reference_kernel<ElementAcc><<<grid, block, shared_mem>>>(
|
||||
problem_shape, seqlen_kv, cache_batch_idx,
|
||||
mQ, mNewK, mNewV, mCacheK, mCacheV, mO
|
||||
);
|
||||
}
|
||||
163
examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp
Normal file
163
examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp
Normal file
@ -0,0 +1,163 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "collective/fmha_fusion.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ProblemShapeIn,
|
||||
class TensorQ,
|
||||
class TensorK,
|
||||
class TensorV,
|
||||
class TensorO,
|
||||
class TensorLSE,
|
||||
class Mask
|
||||
>
|
||||
void __global__ fmha_reference_kernel(
|
||||
ProblemShapeIn problem_shape_in,
|
||||
TensorQ mQ, TensorK mK, TensorV mV,
|
||||
TensorO mO, TensorLSE mLSE,
|
||||
Mask mask) {
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::collective;
|
||||
|
||||
using Element = typename TensorO::value_type;
|
||||
using ElementAccumulator = typename TensorLSE::value_type;
|
||||
|
||||
extern __shared__ char mS_mem[];
|
||||
ElementAccumulator* mS = reinterpret_cast<ElementAccumulator*>(mS_mem);
|
||||
|
||||
ElementAccumulator softmax_scale = static_cast<ElementAccumulator>(1.0 / sqrt(1.0 * size<1>(mO)));
|
||||
|
||||
auto id = make_identity_tensor(make_shape(1, 1));
|
||||
for (int idx_L = blockIdx.y; idx_L < size<3>(problem_shape_in); idx_L += gridDim.y) {
|
||||
for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape_in); idx_Q += gridDim.x) {
|
||||
|
||||
auto coord_L = idx2crd(idx_L, shape<3>(problem_shape_in));
|
||||
auto coord_in = cute::make_tuple(idx_Q, _0{}, _0{}, coord_L);
|
||||
auto [problem_shape, coord] = apply_variable_length(problem_shape_in, coord_in, get<3,1>(coord_in));
|
||||
|
||||
if (get<0,0>(coord) >= get<0>(problem_shape)) continue;
|
||||
|
||||
int offset_Q = 0;
|
||||
if constexpr (rank<0>(decltype(coord){}) == 2) {
|
||||
offset_Q = get<0,1>(coord);
|
||||
}
|
||||
|
||||
int offset_K = 0;
|
||||
if constexpr (rank<1>(decltype(coord){}) == 2) {
|
||||
offset_K = get<1,1>(coord);
|
||||
}
|
||||
|
||||
for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) {
|
||||
ElementAccumulator acc = 0;
|
||||
for (int idx_D = 0; idx_D < size<2>(problem_shape); idx_D++) {
|
||||
ElementAccumulator eQ = mQ(idx_Q + offset_Q, idx_D, idx_L);
|
||||
ElementAccumulator eK = mK(idx_K + offset_K, idx_D, idx_L);
|
||||
acc += eQ * eK;
|
||||
}
|
||||
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
|
||||
frag(0) = acc;
|
||||
mask.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
|
||||
mS[idx_K] = frag(0);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
ElementAccumulator maxS = -std::numeric_limits<ElementAccumulator>::infinity();
|
||||
for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) {
|
||||
maxS = std::max<ElementAccumulator>(maxS, mS[idx_K]);
|
||||
}
|
||||
if (maxS == -std::numeric_limits<ElementAccumulator>::infinity()) maxS = 0;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) {
|
||||
mS[idx_K] = expf(softmax_scale * (mS[idx_K] - maxS));
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
ElementAccumulator sum = 0;
|
||||
for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) {
|
||||
sum += mS[idx_K];
|
||||
}
|
||||
|
||||
ElementAccumulator scale = 1.0f / sum;
|
||||
|
||||
for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) {
|
||||
ElementAccumulator acc = 0;
|
||||
for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) {
|
||||
ElementAccumulator eV = mV(idx_K + offset_K, idx_D, idx_L);
|
||||
ElementAccumulator eK = static_cast<Element>(mS[idx_K]);
|
||||
acc += eK * eV;
|
||||
}
|
||||
mO(idx_Q + offset_Q, idx_D, idx_L) = static_cast<typename TensorO::value_type>(acc * scale);
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
mLSE(idx_Q + offset_Q, idx_L) = log(sum) + maxS;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class ProblemShapeIn,
|
||||
class TensorQ,
|
||||
class TensorK,
|
||||
class TensorV,
|
||||
class TensorO,
|
||||
class TensorLSE,
|
||||
class Mask
|
||||
>
|
||||
void fmha_reference(
|
||||
ProblemShapeIn problem_shape_in,
|
||||
TensorQ mQ, TensorK mK, TensorV mV,
|
||||
TensorO mO, TensorLSE mLSE,
|
||||
Mask mask) {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
dim3 grid(size<0>(mO), size<2>(mO), 1);
|
||||
dim3 block(256);
|
||||
int shared_mem = size<0>(mK) * int(sizeof(typename TensorLSE::value_type));
|
||||
fmha_reference_kernel<<<grid, block, shared_mem>>>(problem_shape_in, mQ, mK, mV, mO, mLSE, mask);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
180
examples/77_blackwell_fmha/reference/reference_abs_error.hpp
Normal file
180
examples/77_blackwell_fmha/reference/reference_abs_error.hpp
Normal file
@ -0,0 +1,180 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <math.h>
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct DeviceAllocation {
|
||||
T* ptr_ = nullptr;
|
||||
size_t offset_ = 0;
|
||||
size_t size_ = 0;
|
||||
|
||||
DeviceAllocation(DeviceAllocation const&) = delete;
|
||||
DeviceAllocation& operator=(DeviceAllocation const&) = delete;
|
||||
|
||||
DeviceAllocation() = default;
|
||||
DeviceAllocation(size_t size) { reset(size); }
|
||||
~DeviceAllocation() { reset(); }
|
||||
|
||||
void reset(size_t size, size_t offset=0) {
|
||||
reset();
|
||||
auto ret = cudaMalloc(&ptr_, sizeof(T) * (size + offset));
|
||||
assert(ret == cudaSuccess);
|
||||
size_ = size;
|
||||
offset_ = offset;
|
||||
}
|
||||
|
||||
T* get() {
|
||||
return ptr_ + offset_;
|
||||
}
|
||||
|
||||
const T* get() const {
|
||||
return ptr_ + offset_;
|
||||
}
|
||||
|
||||
void reset() {
|
||||
if (ptr_ != nullptr) {
|
||||
auto ret = cudaFree(ptr_);
|
||||
assert(ret == cudaSuccess);
|
||||
}
|
||||
}
|
||||
|
||||
size_t size() const { return size_; }
|
||||
|
||||
void copy_from_host(const T* ptr, size_t sz) {
|
||||
auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault);
|
||||
assert(ret == cudaSuccess);
|
||||
}
|
||||
|
||||
void copy_from_device(const T* ptr, size_t sz) {
|
||||
auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault);
|
||||
assert(ret == cudaSuccess);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Element>
|
||||
__global__ void reference_abs_diff_kernel(
|
||||
Element* data, Element* data_ref, size_t count,
|
||||
double* max_diff, double* sum_diff,
|
||||
bool print_diff ) {
|
||||
|
||||
double thread_max_diff = 0;
|
||||
double thread_sum_diff = 0;
|
||||
|
||||
__shared__ double block_max_diff;
|
||||
__shared__ double block_sum_diff;
|
||||
|
||||
for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) {
|
||||
double diff = fabs(data[i] - data_ref[i]);
|
||||
if (print_diff) if (diff != diff || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long int>(i), diff, (double)data[i], (double)data_ref[i]);
|
||||
thread_max_diff = fmax(diff, thread_max_diff);
|
||||
thread_sum_diff += diff;
|
||||
}
|
||||
|
||||
for (int i = 0; i < blockDim.x; i++) {
|
||||
if (i == threadIdx.x) {
|
||||
if (i == 0) {
|
||||
block_max_diff = thread_max_diff;
|
||||
block_sum_diff = thread_sum_diff;
|
||||
}
|
||||
else {
|
||||
block_max_diff = fmax(block_max_diff, thread_max_diff);
|
||||
block_sum_diff += thread_sum_diff;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
atomicAdd(sum_diff, block_sum_diff);
|
||||
|
||||
for (;;) {
|
||||
unsigned long long prev = *reinterpret_cast<unsigned long long*>(max_diff);
|
||||
double prev_diff = reinterpret_cast<double const&>(prev);
|
||||
double new_max_diff = fmax(block_max_diff, prev_diff);
|
||||
unsigned long long found = atomicCAS(reinterpret_cast<unsigned long long*>(max_diff), prev, reinterpret_cast<unsigned long long const&>(new_max_diff));
|
||||
if (found == prev) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Element>
|
||||
void reference_abs_diff(
|
||||
DeviceAllocation<Element> const& data,
|
||||
DeviceAllocation<Element> const& data_ref,
|
||||
double& max_diff, double& mean_diff) {
|
||||
|
||||
static bool kPrintDiff = getenv("REF_PRINT_DIFF") && atoi(getenv("REF_PRINT_DIFF")) == 1;
|
||||
|
||||
DeviceAllocation<double> result;
|
||||
result.reset(2);
|
||||
assert(data.size() == data_ref.size());
|
||||
|
||||
cudaError_t err = cudaMemset(result.get(), 0, result.size() * sizeof(double));
|
||||
if (err != cudaSuccess) {
|
||||
std::cerr << "Memset failed. Last CUDA error: "
|
||||
<< cudaGetErrorString(err) << std::endl;
|
||||
max_diff = mean_diff = 1e20;
|
||||
return;
|
||||
}
|
||||
|
||||
dim3 block(256, 1, 1);
|
||||
dim3 grid(1024, 1, 1);
|
||||
reference_abs_diff_kernel<<<block, grid>>>(
|
||||
data.get(), data_ref.get(), data.size(),
|
||||
result.get(), result.get() + 1, kPrintDiff);
|
||||
|
||||
err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) {
|
||||
std::cerr << "Difference kernel failed. Last CUDA error: "
|
||||
<< cudaGetErrorString(err) << std::endl;
|
||||
max_diff = mean_diff = 1e20;
|
||||
return;
|
||||
}
|
||||
|
||||
double result_host[2];
|
||||
err = cudaMemcpy(result_host, result.get(), result.size() * sizeof(double), cudaMemcpyDefault);
|
||||
if (err != cudaSuccess) {
|
||||
std::cerr << "Copy failed. Last CUDA error: "
|
||||
<< cudaGetErrorString(err) << std::endl;
|
||||
max_diff = mean_diff = 1e20;
|
||||
return;
|
||||
}
|
||||
|
||||
max_diff = result_host[0];
|
||||
mean_diff = result_host[1] / static_cast<double>(data.size());
|
||||
}
|
||||
@ -146,6 +146,14 @@ foreach(EXAMPLE
|
||||
64_ada_fp8_gemm_grouped
|
||||
65_distributed_gemm
|
||||
67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling
|
||||
70_blackwell_gemm
|
||||
71_blackwell_gemm_with_collective_builder
|
||||
72_blackwell_narrow_precision_gemm
|
||||
73_blackwell_gemm_preferred_cluster
|
||||
74_blackwell_gemm_streamk
|
||||
75_blackwell_grouped_gemm
|
||||
76_blackwell_conv
|
||||
77_blackwell_fmha
|
||||
)
|
||||
|
||||
add_subdirectory(${EXAMPLE})
|
||||
|
||||
293
examples/README.md
Normal file
293
examples/README.md
Normal file
@ -0,0 +1,293 @@
|
||||
# CUTLASS - Programming Examples
|
||||
|
||||
* [00_basic_gemm](00_basic_gemm/)
|
||||
|
||||
launches a basic GEMM with single precision inputs and outputs
|
||||
|
||||
* [01_cutlass_utilities](01_cutlass_utilities/)
|
||||
|
||||
demonstrates CUTLASS Utilities for allocating and initializing tensors
|
||||
|
||||
* [02_dump_reg_smem](02_dump_reg_smem/)
|
||||
|
||||
debugging utilities for printing register and shared memory contents
|
||||
|
||||
* [03_visualize_layout](03_visualize_layout/)
|
||||
|
||||
utility for visualizing all layout functions in CUTLASS
|
||||
|
||||
* [04_tile_iterator](04_tile_iterator/)
|
||||
|
||||
example demonstrating an iterator over tiles in memory
|
||||
|
||||
* [05_batched_gemm](05_batched_gemm/)
|
||||
|
||||
example demonstrating CUTLASS's batched strided GEMM operation
|
||||
|
||||
* [06_splitK_gemm](06_splitK_gemm/)
|
||||
|
||||
example demonstrating CUTLASS's Split-K parallel reduction kernel
|
||||
|
||||
* [07_volta_tensorop_gemm](07_volta_tensorop_gemm/)
|
||||
|
||||
example demonstrating mixed precision GEMM using Volta Tensor Cores
|
||||
|
||||
* [08_turing_tensorop_gemm](08_turing_tensorop_gemm/)
|
||||
|
||||
example demonstrating integer GEMM using Turing Tensor Cores
|
||||
|
||||
* [09_turing_tensorop_conv2dfprop](09_turing_tensorop_conv2dfprop/)
|
||||
|
||||
example demonstrating integer implicit GEMM convolution (forward propagation) using Turing Tensor Cores
|
||||
|
||||
* [10_planar_complex](10_planar_complex/)
|
||||
|
||||
example demonstrating planar complex GEMM kernels
|
||||
|
||||
* [11_planar_complex_array](11_planar_complex_array/)
|
||||
|
||||
example demonstrating planar complex kernels with batch-specific problem sizes
|
||||
|
||||
* [12_gemm_bias_relu](12_gemm_bias_relu/)
|
||||
|
||||
example demonstrating GEMM fused with bias and relu
|
||||
|
||||
* [13_two_tensor_op_fusion](13_two_tensor_op_fusion/)
|
||||
|
||||
example demonstrating two GEMMs or convolutions fused in one kernel
|
||||
|
||||
* [14_ampere_tf32_tensorop_gemm](14_ampere_tf32_tensorop_gemm/)
|
||||
|
||||
example demonstrating FP32 GEMM with implicit TF32 conversion
|
||||
|
||||
* [15_ampere_sparse_tensorop_gemm](15_ampere_sparse_tensorop_gemm/)
|
||||
|
||||
example demonstrating usage of Sparse Tensor cores
|
||||
|
||||
* [16_ampere_tensorop_conv2dfprop](16_ampere_tensorop_conv2dfprop/)
|
||||
|
||||
example demonstrating forward convolution on tensors of layout NHWC
|
||||
|
||||
* [17_fprop_per_channel_bias](17_fprop_per_channel_bias/)
|
||||
|
||||
example demonstrating convolution fused with per channel bias and relu
|
||||
|
||||
* [18_ampere_fp64_tensorop_affine2_gemm](18_ampere_fp64_tensorop_affine2_gemm/)
|
||||
|
||||
example demonstrating Affine-2 GEMM
|
||||
|
||||
* [19_tensorop_canonical](19_tensorop_canonical/)
|
||||
|
||||
Canonical GEMM using tensor cores
|
||||
|
||||
* [20_simt_canonical](20_simt_canonical/)
|
||||
|
||||
Canonical GEMM using SIMT
|
||||
|
||||
* [21_quaternion_gemm](21_quaternion_gemm/)
|
||||
|
||||
example demonstrating Quaternion GEMM computations
|
||||
|
||||
* [22_quaternion conv](22_quaternion_conv/)
|
||||
|
||||
example demonstrating Quaternion convolution
|
||||
|
||||
* [23_ampere_gemm_operand_reduction_fusion](23_ampere_gemm_operand_reduction_fusion/)
|
||||
|
||||
example demonstrating how to reduce one of the operands of the GEMM along the k-dimension when computing GEMM
|
||||
|
||||
* [24_gemm_grouped](24_gemm_grouped/)
|
||||
|
||||
example demonstrating batch of GEMM operations with distinct problem sizes
|
||||
|
||||
* [25_ampere_fprop_mainloop_fusion](25_ampere_fprop_mainloop_fusion/)
|
||||
|
||||
example demonstrating fusing activation's per channel scale+bias+relu into the fgrad mainloop
|
||||
|
||||
* [26_ampere_wgrad_mainloop_fusion](26_ampere_wgrad_mainloop_fusion/)
|
||||
|
||||
example demonstrating fusing activation's per channel scale+bias+relu into the wgrad mainloop
|
||||
|
||||
* [27_ampere_3xtf32_fast_accurate_tensorop_gemm](27_ampere_3xtf32_fast_accurate_tensorop_gemm/)
|
||||
|
||||
example demonstrating emulation of a fast accurate SGEMM with TF32 operations
|
||||
|
||||
* [28_ampere_3xtf32_fast_accurate_tensorop_fprop](28_ampere_3xtf32_fast_accurate_tensorop_fprop/)
|
||||
|
||||
example demonstrating emulation of a fast accurate FP32 convolution with TF32 operation
|
||||
|
||||
* [29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm](29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/)
|
||||
|
||||
example demonstrating emulation of a fast accurate CGEMM with TF32 operation
|
||||
|
||||
* [30_wgrad_split_k](30_wgrad_split_k/)
|
||||
|
||||
example demonstrating how to compute conv2d gradient with respect to weight (wgrad) together with split-K
|
||||
|
||||
* [31_basic_syrk](31_basic_syrk/)
|
||||
|
||||
example demonstrating Symmetric Rank-K update
|
||||
|
||||
* [32_basic_trmm](32_basic_trmm/)
|
||||
|
||||
example demonstrating Triangular Matrix-Matrix multiplication
|
||||
|
||||
* [33_ampere_3xtf32_tensorop_symm](33_ampere_3xtf32_tensorop_symm/)
|
||||
|
||||
example demonstrating Symmetric Matrix-Matrix multiplication with FP32 emulation
|
||||
|
||||
* [34_transposed_conv2d](34_transposed_conv2d/)
|
||||
|
||||
example demonstrating how to compute 2d transposed convolution, also known as deconvolution, using CUTLASS conv2d Dgrad kernels
|
||||
|
||||
* [35_gemm_softmax](35_gemm_softmax/)
|
||||
|
||||
example demonstrating GEMM fused with Softmax in mixed precision using Ampere Tensor Cores
|
||||
|
||||
* [36_gather_scatter_fusion](36_gather_scatter_fusion/)
|
||||
|
||||
example demonstrating fuses gather before GEMM and scatter after GEMM into the same GEMM kernel
|
||||
|
||||
* [37_gemm_layernorm_gemm_fusion](37_gemm_layernorm_gemm_fusion/)
|
||||
|
||||
example demonstrating fuses gemm->layernorm->gemm into one kernel.
|
||||
|
||||
* [38_syr2k_grouped](38_syr2k_grouped/)
|
||||
|
||||
example demonstrating a batch of SYR2K operations with distinct problem sizes
|
||||
|
||||
* [39_gemm_permute](39_gemm_permute/)
|
||||
|
||||
example demonstrating batched GEMM operations with output results permuted as reshaped tensors
|
||||
|
||||
* [40_cutlass_py](40_cutlass_py/)
|
||||
|
||||
example demonstrating CUTLASS with Python interface
|
||||
|
||||
* [41_multi_head_attention](41_multi_head_attention/)
|
||||
|
||||
example demonstrating attention example with non-fixed sequence length input
|
||||
|
||||
* [42_ampere_tensorop_group_conv](42_ampere_tensorop_group_conv/)
|
||||
|
||||
example demonstrating how to run group convolution kernels using functions and data structures provided by CUTLASS using tensor cores
|
||||
|
||||
* [43_ell_block_sparse_gemm](43_ell_block_sparse_gemm/)
|
||||
|
||||
example demonstrating a Block-Ell sparse gemm
|
||||
|
||||
* [44_fused_multi_head_attention](44_fused_multi_head_attention/)
|
||||
|
||||
example demonstrating fused multihead attention (fixed & variable) using shared memory
|
||||
|
||||
* [45_dual_gemm](45_dual_gemm/)
|
||||
|
||||
example demonstrating how to fuse two GEMMs sharing the same left input matrix into one kernel
|
||||
|
||||
* [46_depthwise_simt_conv2dfprop](46_depthwise_simt_conv2dfprop/)
|
||||
|
||||
example demonstrating depthwise 2d convolution kernels using functions and data structures provided by CUTLASS using SIMT instruction
|
||||
|
||||
* [47_ampere_gemm_universal_streamk](47_ampere_gemm_universal_streamk/)
|
||||
|
||||
example contrasting the Stream-K parallel decomposition for GEMM threadblocks versus the
|
||||
"classic data-parallel" and "Split-K" decompositions.
|
||||
|
||||
* [48_hopper_warp_specialized_gemm](48_hopper_warp_specialized_gemm/)
|
||||
|
||||
Simple tensorop GEMM example using CUTLASS 3.0 APIs targeting NVIDIA Hopper architecture
|
||||
|
||||
* [49_hopper_gemm_schedules_with_collective_builder](49_hopper_gemm_schedules_with_collective_builder/)
|
||||
|
||||
Hopper GEMM example leveraging collective operation builders to showcase the builder API and the various kernel scheduled supported in CUTLASS 3.0 such as warp specialized persistent mainloops.
|
||||
|
||||
* [50_hopper_gemm_with_epilogue_swizzle](50_hopper_gemm_with_epilogue_swizzle/)
|
||||
|
||||
Hopper GEMM example to create a GEMM kernel with custom a collective mainloop and a custom vectorized epilogue.
|
||||
|
||||
* [51_hopper_gett](51_hopper_gett/)
|
||||
|
||||
Hopper GETT example illustrating the ease with which GETTs can be run due to CUTLASS 3.0's unified micro-kernels and CuTe's hierarchical layouts.
|
||||
|
||||
* [52_hopper_gather_scatter_fusion](52_hopper_gather_scatter_fusion/)
|
||||
|
||||
Hopper example that fuses gather before GEMM and scatter after GEMM into the same kernel
|
||||
|
||||
* [53_hopper_gemm_permute](53_hopper_gemm_permute/)
|
||||
|
||||
Hopper example demonstrating the fusion of tensor permutation operations with a GEMM kernel
|
||||
|
||||
* [54_hopper_fp8_warp_specialized_gemm](54_hopper_fp8_warp_specialized_gemm/)
|
||||
|
||||
Hopper example of instantiating and running an FP8 GEMM kernel
|
||||
|
||||
* [55_hopper_mixed_dtype_gemm](55_hopper_mixed_dtype_gemm/)
|
||||
|
||||
Hopper GEMM example with different A and B data types using CUTLASS 3.x APIs for DL kernels with fused dequantization.
|
||||
|
||||
* [56_hopper_ptr_array_batched_gemm](56_hopper_ptr_array_batched_gemm/)
|
||||
|
||||
Hopper Ptr-Array Batched GEMM example using CUTLASS 3.x API.
|
||||
|
||||
* [57_hopper_grouped_gemm](57_hopper_grouped_gemm/)
|
||||
|
||||
Hopper Grouped GEMM using CUTLASS 3.x API.
|
||||
|
||||
* [58_ada_fp8_gemm](58_ada_fp8_gemm/)
|
||||
|
||||
Ada GEMM kernel targetting Ada FP8 tensor cores via the CUTLASS 2.x API.
|
||||
|
||||
* [59_ampere_gather_scatter_conv](59_ampere_gather_scatter_conv/)
|
||||
|
||||
CuTe and CUTLASS 3.x based Ampere convolution fprop kernel capable of operating on both affine and gather/scatter tensors,
|
||||
showing how kernel authors can re-use CUTLASS 3.x collectives in their custom kernels.
|
||||
|
||||
* [61_hopper_gemm_with_topk_and_softmax](61_hopper_gemm_with_topk_and_softmax/)
|
||||
|
||||
Hopper GEMM kernel with Top-K and softmax epilogue fusion.
|
||||
|
||||
[//]: #
|
||||
|
||||
* [70_blackwell_gemm](70_blackwell_gemm)
|
||||
|
||||
Simple dense GEMM example targeting the NVIDIA Blackwell SM100 Tensor Core MMA using CUTLASS 3.x APIs.
|
||||
|
||||
* [71_blackwell_gemm_with_collective_builder](71_blackwell_gemm_with_collective_builder)
|
||||
|
||||
Blackwell SM100 GEMM example demonstrating compatible mainloop+epilogue builder schedules and epilogue visitor tree (EVT) construction
|
||||
|
||||
* [72a_blackwell_narrow_precision_gemm](72a_blackwell_narrow_precision_gemm)
|
||||
|
||||
Block-scaled dense GEMM example targeting the NVIDIA Blackwell SM100 Tensor Core MMA using CUTLASS 3.x APIs.
|
||||
|
||||
* [73_blackwell_gemm_preferred_cluster](73_blackwell_gemm_preferred_cluster/)
|
||||
|
||||
Blackwell SM100 GEMM kernel with preferred cluster feature.
|
||||
|
||||
* [74_blackwell_gemm_streamk](74_blackwell_gemm_streamk/)
|
||||
|
||||
Blackwell SM100 GEMM kernel using the Stream-K scheduler
|
||||
|
||||
* [75_blackwell_grouped_gemm](75_blackwell_grouped_gemm)
|
||||
|
||||
Blackwell SM100 grouped GEMM kernel
|
||||
|
||||
* [76_blackwell_conv](76_blackwell_conv/)
|
||||
|
||||
Simple convolution(fprop/dgrad/wgrad) example targeting NVIDIA Blackwell SM100 Tensor Core MMA using CUTLASS 3.x APIs.
|
||||
|
||||
* [77_blackwell_fmha](77_blackwell_fmha)
|
||||
|
||||
Blackwell SM100 FMHA kernel
|
||||
|
||||
[//]: #
|
||||
|
||||
# CuTe - Programming Examples
|
||||
|
||||
Examples that do not rely on CUTLASS and directly showcase the features of CuTe are located in [cutlass/examples/cute](./cute/).
|
||||
|
||||
Additionally, CuTe's core layout and layout algebra have their own test cases within [cutlass/test/unit/cute/core/](../test/unit/cute/core/) that users might find useful as examples of CuTe.
|
||||
|
||||
# Python Interface Examples
|
||||
|
||||
Examples leveraging CUTLASS's [Python interface](../python/README.md) are located in [cutlass/examples/python](python/).
|
||||
108
include/cute/arch/cluster_sm100.hpp
Executable file
108
include/cute/arch/cluster_sm100.hpp
Executable file
@ -0,0 +1,108 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
//
|
||||
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
|
||||
namespace cute {
|
||||
|
||||
|
||||
//
|
||||
// Cluster launch utility
|
||||
//
|
||||
CUTE_HOST
|
||||
bool
|
||||
initialize_preferred_cluster_launch(void const* const kernel_function,
|
||||
dim3 const& grid_dims,
|
||||
dim3 const& cluster_dims_preferred,
|
||||
dim3 const& cluster_dims_fallback)
|
||||
{
|
||||
//
|
||||
// Validate cluster_dims
|
||||
//
|
||||
|
||||
// Total number of cluster cannot be greater than 32 (hardware requirement)
|
||||
if (cluster_dims_preferred.x * cluster_dims_preferred.y * cluster_dims_preferred.z <= 0 ||
|
||||
cluster_dims_preferred.x * cluster_dims_preferred.y * cluster_dims_preferred.z > 32) {
|
||||
std::cout << "Invalid preferred cluster dimensions: Attempting to init preferred cluster (" << cluster_dims_preferred.x << "," << cluster_dims_preferred.y << "," << cluster_dims_preferred.z
|
||||
<< ") [" << (cluster_dims_preferred.x * cluster_dims_preferred.y * cluster_dims_preferred.z) << "] which must be within (0,32]." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Total number of cluster cannot be greater than 32 (hardware requirement)
|
||||
if (cluster_dims_fallback.x * cluster_dims_fallback.y * cluster_dims_fallback.z <= 0 ||
|
||||
cluster_dims_fallback.x * cluster_dims_fallback.y * cluster_dims_fallback.z > 32) {
|
||||
std::cout << "Invalid cluster dimensions: Attempting to init fallback cluster (" << cluster_dims_fallback.x << "," << cluster_dims_fallback.y << "," << cluster_dims_fallback.z
|
||||
<< ") [" << (cluster_dims_fallback.x * cluster_dims_fallback.y * cluster_dims_fallback.z) << "] which must be within (0,32]." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Total grid dimensions must be within (2^32, 2^16, 2^16)
|
||||
if (grid_dims.y > (1 << 16) || grid_dims.z > (1 << 16)) {
|
||||
std::cout << "Invalid grid dimensions: Attempting to init grid dimensions (" << grid_dims.x << "," << grid_dims.y << "," << grid_dims.z
|
||||
<< ") which must be within (2^32, 2^16, 2^16)." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// grid_dims should be divisible by cluster_dims_preferred
|
||||
if (grid_dims.x % cluster_dims_preferred.x != 0 ||
|
||||
grid_dims.y % cluster_dims_preferred.y != 0 ||
|
||||
grid_dims.z % cluster_dims_preferred.z != 0) {
|
||||
std::cout << "Invalid grid dimensions: Preferred cluster (" << cluster_dims_preferred.x << "," << cluster_dims_preferred.y << "," << cluster_dims_preferred.z
|
||||
<< ") does not divide Grid (" << grid_dims.x << "," << grid_dims.y << "," << grid_dims.z << ")." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// cluster_dims_preferred should be divisible by cluster_dims_fallback
|
||||
if (cluster_dims_preferred.x % cluster_dims_fallback.x != 0 ||
|
||||
cluster_dims_preferred.y % cluster_dims_fallback.y != 0 ||
|
||||
cluster_dims_preferred.z % cluster_dims_fallback.z != 0) {
|
||||
std::cout << "Invalid cluster dimensions: Fallback cluster (" << cluster_dims_fallback.x << "," << cluster_dims_fallback.y << "," << cluster_dims_fallback.z
|
||||
<< ") does not divide Preferred cluster (" << cluster_dims_preferred.x << "," << cluster_dims_preferred.y << "," << cluster_dims_preferred.z << ")." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Both cluster dimenions should have the same depth
|
||||
if (cluster_dims_preferred.z != cluster_dims_fallback.z) {
|
||||
std::cout << "Invalid cluster dimensions: Fallback cluster (" << cluster_dims_fallback.x << "," << cluster_dims_fallback.y << "," << cluster_dims_fallback.z
|
||||
<< ") and Preferred cluster (" << cluster_dims_preferred.x << "," << cluster_dims_preferred.y << "," << cluster_dims_preferred.z << ") does not have the same depth." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
} // end namespace cute
|
||||
@ -48,3 +48,42 @@
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED))
|
||||
# define CUTE_ARCH_TMA_SM90_ENABLED
|
||||
# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED
|
||||
# define CUTE_ARCH_STSM_SM90_ENABLED
|
||||
#endif
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED)
|
||||
# define CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED
|
||||
# define CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED
|
||||
# define CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED
|
||||
# define CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED
|
||||
# define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED
|
||||
#endif
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED)
|
||||
# define CUTE_ARCH_TCGEN05_S8_MMA_ENABLED
|
||||
#endif
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED)
|
||||
# define CUTE_ARCH_LDSM_SM100A_ENABLED
|
||||
# define CUTE_ARCH_STSM_SM100A_ENABLED
|
||||
#endif
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED)
|
||||
# define CUTE_ARCH_TCGEN05_TMEM_ENABLED
|
||||
#endif
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED)
|
||||
# define CUTE_ARCH_TMA_SM100_ENABLED
|
||||
#endif
|
||||
|
||||
// {add, mul, fma}.f32x2 PTX
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED)
|
||||
#define CUTE_ARCH_FLOAT2_MATH_ENABLED
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
|
||||
7567
include/cute/arch/copy_sm100.hpp
Normal file
7567
include/cute/arch/copy_sm100.hpp
Normal file
File diff suppressed because it is too large
Load Diff
664
include/cute/arch/copy_sm100_tma.hpp
Normal file
664
include/cute/arch/copy_sm100_tma.hpp
Normal file
@ -0,0 +1,664 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2020 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/arch/config.hpp>
|
||||
|
||||
#include <cute/arch/copy.hpp>
|
||||
#include <cute/arch/copy_sm90.hpp>
|
||||
namespace cute
|
||||
{
|
||||
|
||||
constexpr uint32_t Sm100MmaPeerBitMask = 0xFEFFFFFF;
|
||||
constexpr uint64_t Sm100MemDescDefault = uint64_t(0x1000000000000000);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// UTMA_LOAD : Initiates a TMA copy from global memory to shared memory
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SM100_TMA_2SM_LOAD_1D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint,
|
||||
[[maybe_unused]] void * smem_ptr,
|
||||
[[maybe_unused]] int32_t const& crd0)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
// Executed by both CTAs. Set peer bit to 0 so that the
|
||||
// transaction bytes will update CTA0's barrier.
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.1d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
|
||||
" [%0], [%1, {%3}], [%2], %4;"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
|
||||
"r"(crd0), "l"(cache_hint)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM100_TMA_2SM_LOAD_2D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint,
|
||||
[[maybe_unused]] void * smem_ptr,
|
||||
[[maybe_unused]] int32_t const& crd0, int32_t const& crd1)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
// Executed by both CTAs. Set peer bit to 0 so that the
|
||||
// transaction bytes will update CTA0's barrier.
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.2d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
|
||||
" [%0], [%1, {%3, %4}], [%2], %5;"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
|
||||
"r"(crd0), "r"(crd1), "l"(cache_hint)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM100_TMA_2SM_LOAD_3D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint,
|
||||
[[maybe_unused]] void * smem_ptr,
|
||||
[[maybe_unused]] int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
// Executed by both CTAs. Set peer bit to 0 so that the
|
||||
// transaction bytes will update CTA0's barrier.
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.3d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
|
||||
" [%0], [%1, {%3, %4, %5}], [%2], %6;"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
|
||||
"r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM100_TMA_2SM_LOAD_4D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
|
||||
void * smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
// Executed by both CTAs. Set peer bit to 0 so that the
|
||||
// transaction bytes will update CTA0's barrier.
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.4d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
|
||||
" [%0], [%1, {%3, %4, %5, %6}], [%2], %7;"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
|
||||
"r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM100_TMA_2SM_LOAD_5D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
|
||||
void * smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
// Executed by both CTAs. Set peer bit to 0 so that the
|
||||
// transaction bytes will update CTA0's barrier.
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.5d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
|
||||
" [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
|
||||
"r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM100_TMA_2SM_LOAD
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
|
||||
void * smem_ptr,
|
||||
int32_t const& crd0)
|
||||
{
|
||||
return SM100_TMA_2SM_LOAD_1D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
|
||||
void * smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1)
|
||||
{
|
||||
return SM100_TMA_2SM_LOAD_2D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
|
||||
void * smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
|
||||
{
|
||||
return SM100_TMA_2SM_LOAD_3D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
|
||||
void * smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
|
||||
{
|
||||
return SM100_TMA_2SM_LOAD_4D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
|
||||
void * smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
|
||||
{
|
||||
return SM100_TMA_2SM_LOAD_5D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4);
|
||||
}
|
||||
|
||||
using PREFETCH = typename SM90_TMA_LOAD::PREFETCH;
|
||||
};
|
||||
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// TMA_LOAD_MULTICAST: Initiates a TMA copy from global memory to shared memory
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SM100_TMA_2SM_LOAD_MULTICAST_1D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint,
|
||||
void * smem_ptr,
|
||||
int32_t const& crd0)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
// Executed by both CTAs. Set peer bit to 0 so that the
|
||||
// transaction bytes will update CTA0's barrier.
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.1d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint"
|
||||
" [%0], [%1, {%4}], [%2], %3, %5;"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask),
|
||||
"r"(crd0), "l"(cache_hint)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM100_TMA_2SM_LOAD_MULTICAST_2D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint,
|
||||
void * smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
// Executed by both CTAs. Set peer bit to 0 so that the
|
||||
// transaction bytes will update CTA0's barrier.
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.2d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint"
|
||||
" [%0], [%1, {%4, %5}], [%2], %3, %6;"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask),
|
||||
"r"(crd0), "r"(crd1), "l"(cache_hint)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM100_TMA_2SM_LOAD_MULTICAST_3D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint,
|
||||
void * smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
// Executed by both CTAs. Set peer bit to 0 so that the
|
||||
// transaction bytes will update CTA0's barrier.
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.3d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint"
|
||||
" [%0], [%1, {%4, %5, %6}], [%2], %3, %7;"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask),
|
||||
"r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM100_TMA_2SM_LOAD_MULTICAST_4D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint,
|
||||
void * smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
// Executed by both CTAs. Set peer bit to 0 so that the
|
||||
// transaction bytes will update CTA0's barrier.
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.4d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint"
|
||||
" [%0], [%1, {%4, %5, %6, %7}], [%2], %3, %8;"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask),
|
||||
"r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM100_TMA_2SM_LOAD_MULTICAST_5D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint,
|
||||
void * smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
// Executed by both CTAs. Set peer bit to 0 so that the
|
||||
// transaction bytes will update CTA0's barrier.
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.5d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint"
|
||||
" [%0], [%1, {%4, %5, %6, %7, %8}], [%2], %3, %9;"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask),
|
||||
"r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM0_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM100_TMA_2SM_LOAD_MULTICAST
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint,
|
||||
void * smem_ptr,
|
||||
int32_t const& crd0)
|
||||
{
|
||||
return SM100_TMA_2SM_LOAD_MULTICAST_1D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint,
|
||||
void * smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1)
|
||||
{
|
||||
return SM100_TMA_2SM_LOAD_MULTICAST_2D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint,
|
||||
void * smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
|
||||
{
|
||||
return SM100_TMA_2SM_LOAD_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint,
|
||||
void * smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
|
||||
{
|
||||
return SM100_TMA_2SM_LOAD_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2, crd3);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint,
|
||||
void * smem_ptr,
|
||||
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
|
||||
{
|
||||
return SM100_TMA_2SM_LOAD_MULTICAST_5D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4);
|
||||
}
|
||||
|
||||
using PREFETCH = typename SM90_TMA_LOAD::PREFETCH;
|
||||
};
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
struct SM100_TMA_2SM_LOAD_IM2COL_3D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr,
|
||||
void * smem_ptr,
|
||||
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n,
|
||||
uint16_t const& offset_w)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
// Executed by both CTAs. Set peer bit to 0 so that the
|
||||
// transaction bytes will update CTA0's barrier.
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.3d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
|
||||
" [%0], [%1, {%3, %4, %5}], [%2], {%6}, %7;"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
|
||||
"r"(coord_c), "r"(coord_w), "r"(coord_n),
|
||||
"h"(offset_w), "l"(Sm100MemDescDefault)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM100_TMA_2SM_LOAD_IM2COL_4D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr,
|
||||
void * smem_ptr,
|
||||
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n,
|
||||
uint16_t const& offset_w,
|
||||
uint16_t const& offset_h)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
// Executed by both CTAs. Set peer bit to 0 so that the
|
||||
// transaction bytes will update CTA0's barrier.
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.4d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
|
||||
" [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9;"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
|
||||
"r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n),
|
||||
"h"(offset_w), "h"(offset_h), "l"(Sm100MemDescDefault)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM100_TMA_2SM_LOAD_IM2COL_5D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr,
|
||||
void * smem_ptr,
|
||||
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n,
|
||||
uint16_t const& offset_w,
|
||||
uint16_t const& offset_h,
|
||||
uint16_t const& offset_d)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
// Executed by both CTAs. Set peer bit to 0 so that the
|
||||
// transaction bytes will update CTA0's barrier.
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.5d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
|
||||
" [%0], [%1, {%3, %4, %5, %6, %7}], [%2], {%8, %9, %10}, %11;"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
|
||||
"r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n),
|
||||
"h"(offset_w), "h"(offset_h), "h"(offset_d), "l"(Sm100MemDescDefault)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM100_TMA_2SM_LOAD_IM2COL
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr,
|
||||
void * smem_ptr,
|
||||
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n,
|
||||
uint16_t const& offset_w)
|
||||
{
|
||||
return SM100_TMA_2SM_LOAD_IM2COL_3D::copy(desc_ptr, mbar_ptr, smem_ptr,
|
||||
coord_c, coord_w, coord_n,
|
||||
offset_w);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr,
|
||||
void * smem_ptr,
|
||||
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n,
|
||||
uint16_t const& offset_w,
|
||||
uint16_t const& offset_h)
|
||||
{
|
||||
return SM100_TMA_2SM_LOAD_IM2COL_4D::copy(desc_ptr, mbar_ptr, smem_ptr,
|
||||
coord_c, coord_w, coord_h, coord_n,
|
||||
offset_w, offset_h);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr,
|
||||
void * smem_ptr,
|
||||
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n,
|
||||
uint16_t const& offset_w,
|
||||
uint16_t const& offset_h,
|
||||
uint16_t const& offset_d)
|
||||
{
|
||||
return SM100_TMA_2SM_LOAD_IM2COL_5D::copy(desc_ptr, mbar_ptr, smem_ptr,
|
||||
coord_c, coord_w, coord_h, coord_d, coord_n,
|
||||
offset_w, offset_h, offset_d);
|
||||
}
|
||||
|
||||
using PREFETCH = typename SM90_TMA_LOAD_IM2COL::PREFETCH;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_3D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask,
|
||||
void * smem_ptr,
|
||||
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n,
|
||||
uint16_t const& offset_w)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
// Executed by both CTAs. Set peer bit to 0 so that the
|
||||
// transaction bytes will update CTA0's barrier.
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.3d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint"
|
||||
" [%0], [%1, {%3, %4, %5}], [%2], {%6}, %7, %8;"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
|
||||
"r"(coord_c), "r"(coord_w), "r"(coord_n),
|
||||
"h"(offset_w),
|
||||
"h"(multicast_mask),
|
||||
"l"(Sm100MemDescDefault)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_4D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask,
|
||||
void * smem_ptr,
|
||||
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n,
|
||||
uint16_t const& offset_w,
|
||||
uint16_t const& offset_h)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
// Executed by both CTAs. Set peer bit to 0 so that the
|
||||
// transaction bytes will update CTA0's barrier.
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.4d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint"
|
||||
" [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9, %10;"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
|
||||
"r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n),
|
||||
"h"(offset_w), "h"(offset_h),
|
||||
"h"(multicast_mask),
|
||||
"l"(Sm100MemDescDefault)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_5D
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask,
|
||||
void * smem_ptr,
|
||||
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n,
|
||||
uint16_t const& offset_w,
|
||||
uint16_t const& offset_h,
|
||||
uint16_t const& offset_d)
|
||||
{
|
||||
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
|
||||
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
|
||||
// Executed by both CTAs. Set peer bit to 0 so that the
|
||||
// transaction bytes will update CTA0's barrier.
|
||||
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile (
|
||||
"cp.async.bulk.tensor.5d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint"
|
||||
" [%0], [%1, {%3, %4, %5, %6, %7}], [%2], {%8, %9, %10}, %11, %12;"
|
||||
:
|
||||
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
|
||||
"r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n),
|
||||
"h"(offset_w), "h"(offset_h), "h"(offset_d),
|
||||
"h"(multicast_mask),
|
||||
"l"(Sm100MemDescDefault)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct SM100_TMA_2SM_LOAD_IM2COL_MULTICAST
|
||||
{
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask,
|
||||
void * smem_ptr,
|
||||
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n,
|
||||
uint16_t const& offset_w)
|
||||
{
|
||||
return SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask,
|
||||
smem_ptr,
|
||||
coord_c, coord_w, coord_n,
|
||||
offset_w);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask,
|
||||
void * smem_ptr,
|
||||
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n,
|
||||
uint16_t const& offset_w, uint16_t const& offset_h)
|
||||
{
|
||||
return SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask,
|
||||
smem_ptr,
|
||||
coord_c, coord_w, coord_h, coord_n,
|
||||
offset_w, offset_h);
|
||||
}
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask,
|
||||
void * smem_ptr,
|
||||
int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n,
|
||||
uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d)
|
||||
{
|
||||
return SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_5D::copy(desc_ptr, mbar_ptr, multicast_mask,
|
||||
smem_ptr,
|
||||
coord_c, coord_w, coord_h, coord_d, coord_n,
|
||||
offset_w, offset_h, offset_d);
|
||||
}
|
||||
|
||||
using PREFETCH = typename SM90_TMA_LOAD_IM2COL::PREFETCH;
|
||||
};
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
} // end namespace cute
|
||||
@ -140,6 +140,11 @@ enum class SmemSwizzleBits : uint8_t {
|
||||
|
||||
enum class SmemSwizzleBase : uint8_t {
|
||||
SWIZZLE_BASE_16B = 0,
|
||||
|
||||
SWIZZLE_BASE_32B = 1,
|
||||
SWIZZLE_BASE_32B_FLIP_8B = 2,
|
||||
SWIZZLE_BASE_64B = 3,
|
||||
|
||||
};
|
||||
|
||||
enum class OOBFill : uint8_t {
|
||||
@ -184,6 +189,14 @@ enum class CacheHintSm90 : uint64_t {
|
||||
EVICT_LAST = 0x14F0000000000000,
|
||||
};
|
||||
|
||||
|
||||
enum class CacheHintSm100 : uint64_t {
|
||||
EVICT_NORMAL = 0x1000000000000000,
|
||||
EVICT_FIRST = 0x12F0000000000000,
|
||||
EVICT_LAST = 0x14F0000000000000,
|
||||
};
|
||||
|
||||
|
||||
#if (__CUDACC_VER_MAJOR__ >= 12)
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
@ -195,6 +208,7 @@ to_CUtensorMapDataType() {
|
||||
if constexpr (is_same_v<T, uint8_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else
|
||||
if constexpr (is_same_v<T, float_e4m3_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else
|
||||
if constexpr (is_same_v<T, float_e5m2_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else
|
||||
if constexpr (is_same_v<T, type_erased_dynamic_float8_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT8;} else
|
||||
if constexpr (is_same_v<T, uint16_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT16; } else
|
||||
if constexpr (is_same_v<T, uint32_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT32; } else
|
||||
if constexpr (is_same_v<T, uint64_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT64; } else
|
||||
@ -205,6 +219,18 @@ to_CUtensorMapDataType() {
|
||||
if constexpr (is_same_v<T, double>) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; } else
|
||||
if constexpr (is_same_v<T, bfloat16_t>) { return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else
|
||||
if constexpr (is_same_v<T, tfloat32_t>) { return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; } else
|
||||
|
||||
if constexpr (is_same_v<T, float_e2m3_t>) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else
|
||||
if constexpr (is_same_v<T, float_e3m2_t>) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else
|
||||
if constexpr (is_same_v<T, float_e2m1_t>) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B;} else
|
||||
if constexpr (is_same_v<T, cutlass::detail::float_e2m1_unpacksmem_t>) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B;} else
|
||||
if constexpr (is_same_v<T, cutlass::detail::float_e2m3_unpacksmem_t>) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else
|
||||
if constexpr (is_same_v<T, cutlass::detail::float_e3m2_unpacksmem_t>) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else
|
||||
if constexpr (is_same_v<T, detail::type_erased_dynamic_float6_unpacksmem_t>) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else
|
||||
if constexpr (is_same_v<T, type_erased_dynamic_float6_t>) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else
|
||||
if constexpr (is_same_v<T, detail::type_erased_dynamic_float4_unpacksmem_t>) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B;} else
|
||||
if constexpr (is_same_v<T, type_erased_dynamic_float4_t>) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B; } else
|
||||
|
||||
{ static_assert(sizeof(T) < 0, "Unknown TMA Format!"); }
|
||||
}
|
||||
|
||||
@ -221,9 +247,21 @@ to_CUtensorMapSwizzle(SmemSwizzleBits const& t, SmemSwizzleBase const& b) {
|
||||
case SmemSwizzleBits::B64:
|
||||
assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 64B swizzle bits.");
|
||||
return CU_TENSOR_MAP_SWIZZLE_64B;
|
||||
#if (0)
|
||||
case SmemSwizzleBits::B128:
|
||||
assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 128B swizzle bits.");
|
||||
return CU_TENSOR_MAP_SWIZZLE_128B;
|
||||
|
||||
#else
|
||||
case SmemSwizzleBits::B128:
|
||||
switch (b) {
|
||||
default: assert(false && "Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!");
|
||||
case SmemSwizzleBase::SWIZZLE_BASE_16B: return CU_TENSOR_MAP_SWIZZLE_128B;
|
||||
case SmemSwizzleBase::SWIZZLE_BASE_32B: return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B;
|
||||
case SmemSwizzleBase::SWIZZLE_BASE_64B: return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_64B;
|
||||
}
|
||||
#endif
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1157,6 +1157,17 @@ tma_store_arrive() {
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
tma_desc_commit_group() {
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
asm volatile("cp.async.bulk.commit_group;");
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
// Wait until at most Count committed TMA_STOREs are pending and all prior commits are complete
|
||||
template <int Count>
|
||||
CUTE_HOST_DEVICE static void
|
||||
@ -1173,6 +1184,22 @@ tma_store_wait() {
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
// Wait until all TMA descriptor previously issued are safe to be modified after tma_desc_commit_group()
|
||||
CUTE_HOST_DEVICE static void
|
||||
tma_desc_wait_group() {
|
||||
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
||||
asm volatile(
|
||||
"cp.async.bulk.wait_group.read %0;"
|
||||
:
|
||||
: "n"(0)
|
||||
: "memory");
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// TMA_REDUCE_ADD : Initiates a TMA reduce-add from shared memory to global memory
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
42
include/cute/arch/mma_sm100.hpp
Normal file
42
include/cute/arch/mma_sm100.hpp
Normal file
@ -0,0 +1,42 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
//
|
||||
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/arch/config.hpp>
|
||||
#include <cute/arch/mma.hpp>
|
||||
|
||||
namespace cute {
|
||||
|
||||
} // namespace cute
|
||||
652
include/cute/arch/mma_sm100_desc.hpp
Normal file
652
include/cute/arch/mma_sm100_desc.hpp
Normal file
@ -0,0 +1,652 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
//
|
||||
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include <cinttypes>
|
||||
#endif
|
||||
|
||||
#include <cute/arch/config.hpp>
|
||||
|
||||
#include <cute/arch/mma.hpp>
|
||||
|
||||
#include <cute/container/bit_field.hpp>
|
||||
#include <cute/container/array.hpp> // cute::array
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cute {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// UMMA Descriptor and utilities
|
||||
|
||||
// UMMA enums and utilities
|
||||
namespace UMMA
|
||||
{
|
||||
|
||||
enum class Major : uint8_t {
|
||||
K = 0,
|
||||
MN = 1
|
||||
};
|
||||
|
||||
enum class ScaleIn : uint8_t {
|
||||
One = 0,
|
||||
Neg = 1
|
||||
};
|
||||
|
||||
enum class ScaleOut : uint8_t {
|
||||
Zero = 0,
|
||||
One = 1
|
||||
};
|
||||
|
||||
enum class Saturate : uint8_t {
|
||||
False = 0,
|
||||
True = 1
|
||||
};
|
||||
|
||||
enum class LayoutType : uint8_t {
|
||||
SWIZZLE_NONE = 0,
|
||||
SWIZZLE_128B_BASE32B = 1,
|
||||
SWIZZLE_128B = 2,
|
||||
SWIZZLE_64B = 4,
|
||||
SWIZZLE_32B = 6
|
||||
};
|
||||
|
||||
CUTE_HOST_DEVICE char const* to_string(LayoutType const& t) {
|
||||
switch (t) {
|
||||
case LayoutType::SWIZZLE_NONE: return "SWIZZLE_NONE";
|
||||
case LayoutType::SWIZZLE_128B_BASE32B: return "SWIZZLE_128B_BASE32B";
|
||||
case LayoutType::SWIZZLE_128B: return "SWIZZLE_128B";
|
||||
case LayoutType::SWIZZLE_64B: return "SWIZZLE_64B";
|
||||
case LayoutType::SWIZZLE_32B: return "SWIZZLE_32B";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
union SmemDescriptor
|
||||
{
|
||||
uint64_t desc_ = 0;
|
||||
// Bitfield implementation avoids the need for shifts in assignment
|
||||
struct {
|
||||
// start_address, bit [0,14), 4LSB not included
|
||||
uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused
|
||||
// leading dimension byte offset, bit [16,30), 4LSB not included
|
||||
uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused
|
||||
// stride dimension byte offset, bit [32,46), 4LSB not included
|
||||
uint16_t stride_byte_offset_ : 14, version_ : 2; // 14 bits [0,14), 2 bits [14,16)
|
||||
// base_offset, bit [49,52). leading_byte_offset_mode, bit [52,53).
|
||||
uint8_t : 1, base_offset_ : 3, lbo_mode_ : 1, : 3; // 1 bit unused, 3 bits [1,4), 1 bit [4,5), 3 bits unused
|
||||
// layout type, bit [61,64), SWIZZLE_NONE matrix descriptor = 0, SWIZZLE_128B matrix descriptor = 2, SWIZZLE_64B descriptor = 4, SWIZZLE_32B descriptor = 6, SWIZZLE_128B_BASE32B = 1, N/A = 3, N/A = 5, N/A = 7
|
||||
uint8_t : 5, layout_type_ : 3; // 6 bits unused, 3 bits [5,8)
|
||||
};
|
||||
// Seperate the field, as we may only update one part of desc
|
||||
struct {
|
||||
uint32_t lo;
|
||||
uint32_t hi;
|
||||
};
|
||||
|
||||
// Decay to a uint64_t
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
operator uint64_t() const noexcept { return desc_; }
|
||||
};
|
||||
|
||||
enum class F16F32Format : uint8_t {
|
||||
F16 = 0,
|
||||
BF16 = 1,
|
||||
TF32 = 2,
|
||||
};
|
||||
|
||||
CUTE_HOST_DEVICE char const* to_string(F16F32Format const& t) {
|
||||
switch (t) {
|
||||
case F16F32Format::F16: return "F16";
|
||||
case F16F32Format::BF16: return "BF16";
|
||||
case F16F32Format::TF32: return "TF32";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr F16F32Format to_F16F32Format() {
|
||||
if constexpr (is_same_v<T, half_t>) { return F16F32Format::F16; } else
|
||||
if constexpr (is_same_v<T, bfloat16_t>) { return F16F32Format::BF16; } else
|
||||
if constexpr (is_same_v<T, tfloat32_t>) { return F16F32Format::TF32; } else
|
||||
{ static_assert(sizeof(T) == 0, "Unknown type for F16F32Format"); }
|
||||
}
|
||||
|
||||
enum class S8Format : uint8_t {
|
||||
UINT8 = 0,
|
||||
INT8 = 1,
|
||||
};
|
||||
|
||||
CUTE_HOST_DEVICE char const* to_string(S8Format const& t) {
|
||||
switch (t) {
|
||||
case S8Format::UINT8: return "UINT8";
|
||||
case S8Format::INT8: return "INT8";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr S8Format to_S8Format() {
|
||||
if constexpr (is_same_v<T, uint8_t>) { return S8Format::UINT8; } else
|
||||
if constexpr (is_same_v<T, int8_t>) { return S8Format::INT8; } else
|
||||
{ static_assert(sizeof(T) == 0, "Unknown type for S8Format"); }
|
||||
}
|
||||
|
||||
enum class MXF8F6F4Format : uint8_t {
|
||||
E4M3 = 0,
|
||||
E5M2 = 1,
|
||||
E2M3 = 3,
|
||||
E3M2 = 4,
|
||||
E2M1 = 5,
|
||||
INVALID = 7 // an invalid datatype for runtime proxy type
|
||||
};
|
||||
|
||||
CUTE_HOST_DEVICE char const* to_string(MXF8F6F4Format const& t) {
|
||||
switch (t) {
|
||||
case MXF8F6F4Format::E4M3: return "E4M3";
|
||||
case MXF8F6F4Format::E5M2: return "E5M2";
|
||||
case MXF8F6F4Format::E2M3: return "E2M3";
|
||||
case MXF8F6F4Format::E3M2: return "E3M2";
|
||||
case MXF8F6F4Format::E2M1: return "E2M1";
|
||||
case MXF8F6F4Format::INVALID: return "INVALID";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr MXF8F6F4Format to_MXF8F6F4Format() {
|
||||
if constexpr (is_same_v<T, float_e4m3_t>) { return MXF8F6F4Format::E4M3; } else
|
||||
if constexpr (is_same_v<T, float_e5m2_t>) { return MXF8F6F4Format::E5M2; } else
|
||||
if constexpr (is_same_v<T, detail::float_e2m3_unpacksmem_t>) { return MXF8F6F4Format::E2M3; } else
|
||||
if constexpr (is_same_v<T, detail::float_e3m2_unpacksmem_t>) { return MXF8F6F4Format::E3M2; } else
|
||||
if constexpr (is_same_v<T, detail::float_e2m1_unpacksmem_t>) { return MXF8F6F4Format::E2M1; } else
|
||||
{ static_assert(sizeof(T) == 0, "Unknown type for MXF8F6F4Format"); }
|
||||
}
|
||||
|
||||
enum class MXF4Format : uint8_t {
|
||||
E2M1 = 1,
|
||||
};
|
||||
|
||||
CUTE_HOST_DEVICE char const* to_string(MXF4Format const& t) {
|
||||
switch (t) {
|
||||
case MXF4Format::E2M1: return "E2M1";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr MXF4Format to_MXF4Format() {
|
||||
if constexpr (is_same_v<T, float_e2m1_t>) { return MXF4Format::E2M1; } else
|
||||
{ static_assert(sizeof(T) == 0, "Unknown type for MXF4Format"); }
|
||||
}
|
||||
|
||||
enum class ScaleFormat : uint8_t {
|
||||
UE4M3 = 0,
|
||||
UE8M0 = 1,
|
||||
};
|
||||
|
||||
CUTE_HOST_DEVICE char const* to_string(ScaleFormat const& t) {
|
||||
switch (t) {
|
||||
case ScaleFormat::UE4M3: return "UE4M3";
|
||||
case ScaleFormat::UE8M0: return "UE8M0";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr ScaleFormat to_ScaleFormat() {
|
||||
if constexpr (is_same_v<T, float_ue4m3_t>) { return ScaleFormat::UE4M3; } else
|
||||
if constexpr (is_same_v<T, float_ue8m0_t>) { return ScaleFormat::UE8M0; } else
|
||||
{ static_assert(sizeof(T) == 0, "Unknown type for ScaleFormat"); }
|
||||
}
|
||||
|
||||
enum class CFormat : uint8_t {
|
||||
F16 = 0,
|
||||
F32 = 1,
|
||||
S32 = 2,
|
||||
};
|
||||
|
||||
CUTE_HOST_DEVICE char const* to_string(CFormat const& t) {
|
||||
switch (t) {
|
||||
case CFormat::F16: return "F16";
|
||||
case CFormat::F32: return "F32";
|
||||
case CFormat::S32: return "S32";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
enum class MaxShift : uint8_t {
|
||||
NoShift = 0,
|
||||
MaxShift8 = 1,
|
||||
MaxShift16 = 2,
|
||||
MaxShift32 = 3
|
||||
};
|
||||
|
||||
enum class BMatrixBufferId : uint8_t {
|
||||
Zero = 0u,
|
||||
One = 1u,
|
||||
Two = 2u,
|
||||
Three = 3u
|
||||
};
|
||||
|
||||
enum class BMatrixBufferReuse : uint8_t {
|
||||
Keep = 1u,
|
||||
Reuse = 2u,
|
||||
ReuseAndKeep = 3u
|
||||
};
|
||||
|
||||
// using MaskAndShiftB = uint32_t[2];
|
||||
union MaskAndShiftB
|
||||
{
|
||||
uint32_t uri[2];
|
||||
|
||||
struct {
|
||||
// Bitfield implementation avoids the need for shifts in assignment
|
||||
uint8_t start_count_ [4]; // bit [ 0:32) : 8 bits each. Specifies the start count for mask generation.
|
||||
uint32_t first_span_ : 4, // bit [32:36) : 1 bit each. 0 = start where B is used. 1 = start with where B is skipped(0 value is used).
|
||||
: 3, //
|
||||
nzm_ : 1, // bit [39:40) : 0 = Enable the mask. 1 = Disable the mask.
|
||||
skip_span_ : 8, // bit [40:48) : Count-1 (zero encoded in this field specifies use span of 1) of consecutive columns where 0 value is used.
|
||||
use_span_ : 8, // bit [48:55) : Count-1 (zero encoded in this field specifies use span of 1) of consecutive columns where B matrix data is used.
|
||||
shift_ : 6, // bit [56:62) : Shift value for B matrix data.
|
||||
: 2;
|
||||
};
|
||||
};
|
||||
|
||||
template <typename ShapeType, int FLT_S, int CTA_M, int CTA_N>
|
||||
CUTE_HOST_DEVICE constexpr auto
|
||||
make_column_zero_mask(ShapeType conv_q, int32_t cta_coord_q, int32_t num_pixels_skip_left) {
|
||||
|
||||
static_assert(cute::is_same_v<ShapeType, cutlass::FastDivmod> || cute::is_integral<ShapeType>::value);
|
||||
|
||||
cute::array<MaskAndShiftB, FLT_S> column_zero_masks{};
|
||||
|
||||
static_assert(FLT_S == 3, "Filter size not supported.");
|
||||
constexpr int MAX_USE_SPAN_COUNT = 256;
|
||||
constexpr int MAX_SKIP_SPAN_COUNT = 256;
|
||||
|
||||
// conv_q_int used for non-divmod case (add/minus/..)
|
||||
// conv_q used for divmod case (div/mod/...)
|
||||
int32_t conv_q_int = int(conv_q);
|
||||
auto [_, cta_q] = divmod(cta_coord_q * CTA_N, conv_q);
|
||||
|
||||
int step_q = CTA_M == 128 ? CTA_N / 1
|
||||
: CTA_M == 64 ? CTA_N / 2
|
||||
: CTA_M == 32 ? CTA_N / 4
|
||||
: 0;
|
||||
|
||||
for (int mask_iter = 0; mask_iter < int(CTA_N / step_q); ++mask_iter) {
|
||||
|
||||
for (int s_iter = 0; s_iter < FLT_S; s_iter += 1) {
|
||||
|
||||
int32_t skip_span{0}, use_span{0}, nzm{1}, first_span{0}, start_count{0}, shift{0};
|
||||
|
||||
shift = s_iter;
|
||||
|
||||
// Examples for CZM setting
|
||||
// CASE0: (skip_span_ < 0)
|
||||
// | padding |<- conv_q ->|
|
||||
// |skip_span_|<- use_span ->|skip_span_|
|
||||
// -skip_span 0 ^cta_q conv_q-1
|
||||
// 0 ^index
|
||||
//
|
||||
// CASE1: (skip_span_ > 0)
|
||||
// |<- conv_q ->|
|
||||
// |skip_span_|<- use_span ->|skip_span_|
|
||||
// 0 ^cta_q conv_q-1
|
||||
// 0 ^index
|
||||
//
|
||||
// line 0 an input vector from 0 to conv_q with the padding
|
||||
// line 1 shows the different spans we need to skip or load
|
||||
// lines 2-3 show the different coordinates of different boundaries.
|
||||
// CTQ_q is the coordinate of the present cta.
|
||||
|
||||
int32_t skip_span_ = num_pixels_skip_left - shift;
|
||||
int32_t index{0};
|
||||
if (skip_span_ > 0) {
|
||||
auto [_, index_mod] = divmod(cta_q, conv_q);
|
||||
index = index_mod;
|
||||
} else if (skip_span_ < 0) {
|
||||
auto [_, index_mod] = divmod((cta_q - skip_span_), conv_q);
|
||||
index = index_mod;
|
||||
} else {
|
||||
nzm = 0;
|
||||
}
|
||||
skip_span = cute::max(cute::abs(skip_span_), 1);
|
||||
use_span = cute::min(conv_q_int - static_cast<int32_t>(skip_span), MAX_USE_SPAN_COUNT);
|
||||
if (use_span > 0) {
|
||||
first_span = index >= skip_span ? 0 : 1;
|
||||
if ((first_span == 0) && (index + CTA_N < conv_q_int + skip_span)) {
|
||||
nzm = 0;
|
||||
} else {
|
||||
start_count = first_span == 0 ? (use_span - (conv_q_int - index)) : index;
|
||||
}
|
||||
} else {
|
||||
skip_span = MAX_SKIP_SPAN_COUNT;
|
||||
use_span = 1;
|
||||
first_span = 1;
|
||||
start_count = 0;
|
||||
}
|
||||
|
||||
column_zero_masks[s_iter].start_count_[mask_iter] = start_count;
|
||||
column_zero_masks[s_iter].first_span_ |= first_span << mask_iter;
|
||||
column_zero_masks[s_iter].nzm_ |= nzm;
|
||||
column_zero_masks[s_iter].skip_span_ = skip_span - 1;
|
||||
column_zero_masks[s_iter].use_span_ = use_span - 1;
|
||||
column_zero_masks[s_iter].shift_ = shift;
|
||||
|
||||
}
|
||||
|
||||
cta_q += step_q;
|
||||
}
|
||||
|
||||
return column_zero_masks;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr auto to_UMMAFormat() {
|
||||
if constexpr (is_same_v<T, half_t>) { return F16F32Format::F16; } else
|
||||
if constexpr (is_same_v<T, bfloat16_t>) { return F16F32Format::BF16; } else
|
||||
if constexpr (is_same_v<T, tfloat32_t>) { return F16F32Format::TF32; } else
|
||||
if constexpr (is_same_v<T, uint8_t>) { return S8Format::UINT8; } else
|
||||
if constexpr (is_same_v<T, int8_t>) { return S8Format::INT8; } else
|
||||
if constexpr (is_same_v<T, type_erased_dynamic_float8_t>) {return MXF8F6F4Format::INVALID; } else
|
||||
|
||||
if constexpr (is_same_v<T, type_erased_dynamic_float6_t>) {return MXF8F6F4Format::INVALID; } else
|
||||
if constexpr (is_same_v<T, type_erased_dynamic_float4_t>) {return MXF8F6F4Format::INVALID; } else
|
||||
if constexpr (is_same_v<T, detail::type_erased_dynamic_float4_unpacksmem_t>) {return MXF8F6F4Format::INVALID; } else
|
||||
|
||||
if constexpr (is_same_v<T, float_e4m3_t>) { return MXF8F6F4Format::E4M3; } else
|
||||
if constexpr (is_same_v<T, float_e5m2_t>) { return MXF8F6F4Format::E5M2; } else
|
||||
|
||||
if constexpr (is_same_v<T, detail::type_erased_dynamic_float6_unpacksmem_t>) {return MXF8F6F4Format::INVALID; } else
|
||||
if constexpr (is_same_v<T, detail::float_e2m3_unpacksmem_t>) { return MXF8F6F4Format::E2M3; } else
|
||||
if constexpr (is_same_v<T, detail::float_e3m2_unpacksmem_t>) { return MXF8F6F4Format::E3M2; } else
|
||||
if constexpr (is_same_v<T, float_e2m3_t>) { return MXF8F6F4Format::E2M3; } else
|
||||
if constexpr (is_same_v<T, float_e3m2_t>) { return MXF8F6F4Format::E3M2; } else
|
||||
if constexpr (is_same_v<T, detail::float_e2m1_unpacksmem_t>) { return MXF8F6F4Format::E2M1; } else
|
||||
if constexpr (is_same_v<T, float_e2m1_t>) { return MXF4Format::E2M1; } else
|
||||
|
||||
{ static_assert(sizeof(T) == 0, "Unknown type for UMMAFormat"); }
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr CFormat to_CFormat() {
|
||||
if constexpr (is_same_v<T, half_t>) { return CFormat::F16; } else
|
||||
if constexpr (is_same_v<T, float>) { return CFormat::F32; } else
|
||||
if constexpr (is_same_v<T, int32_t>) { return CFormat::S32; } else
|
||||
{ static_assert(sizeof(T) == 0, "Unknown type for CFormat"); }
|
||||
}
|
||||
|
||||
union InstrDescriptor
|
||||
{
|
||||
uint32_t desc_;
|
||||
|
||||
struct {
|
||||
// Bitfield implementation avoids the need for shifts in assignment
|
||||
uint16_t sparse_id2_ : 2, // bit [ 0, 2) : Sparse meta data id2
|
||||
sparse_flag_ : 1, // bit [ 2, 3) : 0 = dense. 1 = sparse. 1 value valid only for F32F16/S8/MXF8F6F4
|
||||
saturate_ : 1, // bit [ 3, 4) : 0 = no saturate. 1 = saturate. 1 value valid only for S8
|
||||
c_format_ : 2, // bit [ 4, 6) : 0 = F16. 1 = F32, 2 = S32
|
||||
: 1, //
|
||||
a_format_ : 3, // bit [ 7,10) : MXF8F6F4Format:0 = E4M3, 1 = E5M2, 3 = E2M3, 4 = E3M2, 5 = E2M1. F32F16Format: 0 = F16, 1 = BF16, 2 = TF32. S8: 0 unsigned 8 bit, 1 signed 8 bit. Boolean MMA: 0 Boolean
|
||||
b_format_ : 3, // bit [10,13) : MXF8F6F4Format:0 = E4M3, 1 = E5M2, 3 = E2M3, 4 = E3M2, 5 = E2M1. F32F16Format: 0 = F16, 1 = BF16, 2 = TF32. S8: 0 unsigned 8 bit, 1 signed 8 bit. Boolean MMA: 0 Boolean
|
||||
a_negate_ : 1, // bit [13,14) : 0 = no negate. 1 = negate. 1 value valid only for F32F16Format and MXF8F6F4Format
|
||||
b_negate_ : 1, // bit [14,15) : 0 = no negate. 1 = negate. 1 value valid only for F32F16Format and MXF8F6F4Format
|
||||
a_major_ : 1; // bit [15,16) : 0 = K-major. 1 = MN-major. Major value of 1 is only valid for E4M3, E5M2, INT8 (signed and unsigned), F16, BF16 and TF32 source formats
|
||||
uint16_t b_major_ : 1, // bit [16,17) : 0 = K-major. 1 = MN-major. Major value of 1 is only valid for E4M3, E5M2, INT8 (signed and unsigned), F16, BF16 and TF32 source formats
|
||||
n_dim_ : 6, // bit [17,23) : 3 LSBs not included. Valid values range from 1 (N=8) to 32 (N=256). All values are not valid for all instruction formats
|
||||
: 1, //
|
||||
m_dim_ : 5, // bit [24,29) : 4 LSBs not included. Valid values are: 4 (M=64), 8 (M=128), 16 (M=256)
|
||||
: 1, //
|
||||
max_shift_ : 2; // bit [30,32) : Maximum shift for WS instruction. Encoded as follows: 0 = no shift, 1 = maximum shift of 8, 2 = maximum shift of 16, 3 = maximum shift of 32.
|
||||
};
|
||||
|
||||
// Decay to a uint32_t
|
||||
CUTE_HOST_DEVICE constexpr explicit
|
||||
operator uint32_t() const noexcept { return desc_; }
|
||||
};
|
||||
|
||||
union InstrDescriptorBlockScaled
|
||||
{
|
||||
uint32_t desc_;
|
||||
|
||||
struct {
|
||||
// Bitfield implementation avoids the need for shifts in assignment
|
||||
uint16_t sparse_id2_ : 2, // bit [ 0, 2) : Sparse meta data id2
|
||||
sparse_flag_ : 1, // bit [ 2, 3) : 0 = dense. 1 = sparse. 1 value valid only for F32F16/S8/MXF8F6F4
|
||||
: 1, //
|
||||
b_sf_id_ : 2, // bit [ 4, 6) : Matrix B Scale Factor ID
|
||||
: 1, //
|
||||
a_format_ : 3, // bit [ 7, 9) : MXF8F6F4Format:0 = E4M3, 1 = E5M2, 3 = E2M3, 4 = E3M2, 5 = E2M1. F32F16Format: 0 = F16, 1 = BF16, 2 = TF32. S8: 0 unsigned 8 bit, 1 signed 8 bit. BMMA: 0 Boolean
|
||||
b_format_ : 3, // bit [10,12) : MXF8F6F4Format:0 = E4M3, 1 = E5M2, 3 = E2M3, 4 = E3M2, 5 = E2M1. F32F16Format: 0 = F16, 1 = BF16, 2 = TF32. S8: 0 unsigned 8 bit, 1 signed 8 bit. BMMA: 0 Boolean
|
||||
a_negate_ : 1, // bit [13,14) : 0 = no negate. 1 = negate. 1 value valid only for F32F16Format and MXF8F6F4Format
|
||||
b_negate_ : 1, // bit [14,15) : 0 = no negate. 1 = negate. 1 value valid only for F32F16Format and MXF8F6F4Format
|
||||
a_major_ : 1; // bit [15,16) : 0 = K-major. 1 = MN-major. Major value of 1 is only valid for E4M3, E5M2, INT8 (signed and unsigned), F16, BF16 and TF32 source formats
|
||||
uint16_t b_major_ : 1, // bit [16,17) : 0 = K-major. 1 = MN-major. Major value of 1 is only valid for E4M3, E5M2, INT8 (signed and unsigned), F16, BF16 and TF32 source formats
|
||||
n_dim_ : 6, // bit [17,23) : 3 LSBs not included. Valid values range from 1 (N=8) to 32 (N=256). All values are not valid for all instruction formats
|
||||
scale_format_ : 1, // bit [23,24) : 0=E4M3, 1=E8M0
|
||||
m_dim_ : 5, // bit [24,29) : 4 LSBs not included. Valid values are: 4 (M=64), 8 (M=128), 16 (M=256)
|
||||
a_sf_id_ : 2, // bit [29,31) : Matrix A Scale Factor ID
|
||||
: 1; //
|
||||
};
|
||||
|
||||
// Decay to a uint32_t
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
operator uint32_t() const noexcept { return desc_; }
|
||||
};
|
||||
|
||||
template <class a_type, class b_type, class c_type,
|
||||
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
|
||||
UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One,
|
||||
UMMA::Saturate c_sat = UMMA::Saturate::False,
|
||||
bool is_sparse = false,
|
||||
UMMA::MaxShift max_shift = UMMA::MaxShift::NoShift>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
UMMA::InstrDescriptor
|
||||
make_instr_desc()
|
||||
{
|
||||
UMMA::InstrDescriptor desc_i = {};
|
||||
|
||||
desc_i.a_format_ = uint8_t(UMMA::to_UMMAFormat<a_type>());
|
||||
desc_i.b_format_ = uint8_t(UMMA::to_UMMAFormat<b_type>());
|
||||
desc_i.c_format_ = uint8_t(UMMA::to_CFormat<c_type>());
|
||||
|
||||
desc_i.m_dim_ = (M >> 4);
|
||||
desc_i.n_dim_ = (N >> 3);
|
||||
|
||||
desc_i.a_major_ = uint8_t(a_major);
|
||||
desc_i.b_major_ = uint8_t(b_major);
|
||||
|
||||
desc_i.a_negate_ = uint8_t(a_neg);
|
||||
desc_i.b_negate_ = uint8_t(b_neg);
|
||||
desc_i.saturate_ = uint8_t(c_sat);
|
||||
|
||||
desc_i.sparse_flag_ = is_sparse; // 1 = Sparse
|
||||
desc_i.sparse_id2_ = 0;
|
||||
|
||||
desc_i.max_shift_ = uint8_t(max_shift);
|
||||
|
||||
return desc_i;
|
||||
}
|
||||
|
||||
template <class a_type, class b_type, class c_type,
|
||||
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
|
||||
UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One,
|
||||
UMMA::Saturate c_sat = UMMA::Saturate::False,
|
||||
bool is_sparse = false,
|
||||
UMMA::MaxShift max_shift = UMMA::MaxShift::NoShift>
|
||||
CUTE_HOST_DEVICE
|
||||
constexpr uint64_t
|
||||
make_runtime_instr_desc(uint16_t sparse_id2 = 0u, uint32_t tmem_e = 0u) {
|
||||
UMMA::InstrDescriptor desc_i = UMMA::make_instr_desc<
|
||||
a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat, is_sparse,
|
||||
max_shift>();
|
||||
|
||||
if constexpr (is_sparse) {
|
||||
desc_i.sparse_id2_ = sparse_id2;
|
||||
}
|
||||
else {
|
||||
assert(sparse_id2 == 0u);
|
||||
}
|
||||
// In current compiler exposure, idescE is a uint64_t. It should contain:
|
||||
// - Lower 32b URe: Specifies the tmem address that stores the sparse metadata.
|
||||
// Only needed for Sparse MMA instructions. Otherwise, ignored.
|
||||
// - Upper 32b URh: Specifies the instruction descriptor.
|
||||
uint64_t idescE = (static_cast<uint64_t>(static_cast<uint32_t>(desc_i)) << 32);
|
||||
|
||||
return idescE;
|
||||
}
|
||||
|
||||
template <bool is_sparse = false>
|
||||
CUTE_HOST_DEVICE
|
||||
constexpr uint64_t
|
||||
make_runtime_instr_desc(UMMA::InstrDescriptor desc_i, uint16_t sparse_id2 = 0u, uint32_t tmem_e = 0u)
|
||||
{
|
||||
if constexpr (is_sparse) {
|
||||
desc_i.sparse_id2_ = sparse_id2;
|
||||
}
|
||||
else {
|
||||
assert(sparse_id2 == 0u);
|
||||
}
|
||||
// In current compiler exposure, idescE is a uint64_t. It should contain:
|
||||
// - Lower 32b URe: Specifies the tmem address that stores the sparse metadata.
|
||||
// Only needed for Sparse MMA instructions. Otherwise, ignored.
|
||||
// - Upper 32b URh: Specifies the instruction descriptor.
|
||||
uint64_t idescE = (static_cast<uint64_t>(static_cast<uint32_t>(desc_i)) << 32);
|
||||
|
||||
return idescE;
|
||||
}
|
||||
|
||||
template <class a_type, class b_type, class c_type, class sf_type,
|
||||
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
|
||||
UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One,
|
||||
bool is_sparse = false>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
UMMA::InstrDescriptorBlockScaled
|
||||
make_instr_desc_block_scaled()
|
||||
{
|
||||
UMMA::InstrDescriptorBlockScaled desc_i = {};
|
||||
|
||||
desc_i.a_format_ = uint8_t(UMMA::to_UMMAFormat<a_type>());
|
||||
desc_i.b_format_ = uint8_t(UMMA::to_UMMAFormat<b_type>());
|
||||
|
||||
desc_i.scale_format_ = uint8_t(UMMA::to_ScaleFormat<sf_type>());
|
||||
desc_i.a_sf_id_ = 0;
|
||||
desc_i.b_sf_id_ = 0;
|
||||
|
||||
desc_i.m_dim_ = (M >> 4);
|
||||
desc_i.n_dim_ = (N >> 3);
|
||||
|
||||
desc_i.a_major_ = uint8_t(a_major);
|
||||
desc_i.b_major_ = uint8_t(b_major);
|
||||
|
||||
desc_i.a_negate_ = uint8_t(a_neg);
|
||||
desc_i.b_negate_ = uint8_t(b_neg);
|
||||
desc_i.sparse_flag_ = is_sparse; // 1 = Sparse
|
||||
desc_i.sparse_id2_ = 0;
|
||||
|
||||
// Below would bring some warnings.
|
||||
#if defined(__GNUC__)
|
||||
# pragma GCC diagnostic ignored "-Wconversion"
|
||||
#endif
|
||||
return desc_i;
|
||||
}
|
||||
|
||||
template <class a_type, class b_type, class c_type, class sf_type,
|
||||
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
|
||||
UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One,
|
||||
bool is_sparse = false>
|
||||
CUTE_HOST_DEVICE
|
||||
constexpr uint64_t
|
||||
make_runtime_instr_desc_block_scaled(uint32_t const tmem_sfa_addr, uint32_t const tmem_sfb_addr,
|
||||
uint16_t const sparse_id2 = 0u, uint32_t const tmem_e = 0u)
|
||||
{
|
||||
UMMA::InstrDescriptorBlockScaled desc_i = UMMA::make_instr_desc_block_scaled<
|
||||
a_type, b_type, c_type, sf_type, M, N,
|
||||
a_major, b_major,
|
||||
a_neg, b_neg,
|
||||
is_sparse>();
|
||||
|
||||
// The first 2-bits of TMEM address includes byte address.
|
||||
desc_i.a_sf_id_ = (tmem_sfa_addr & 0xC0000000) >> 30;
|
||||
desc_i.b_sf_id_ = (tmem_sfb_addr & 0xC0000000) >> 30;
|
||||
|
||||
if constexpr (is_sparse) {
|
||||
desc_i.sparse_id2_ = sparse_id2;
|
||||
}
|
||||
else {
|
||||
assert(sparse_id2 == 0u);
|
||||
}
|
||||
|
||||
// In current compiler exposure, idescE is a uint64_t. It should contain:
|
||||
// - Lower 32b URe: Specifies the tmem address that stores the sparse metadata.
|
||||
// Only needed for Sparse MMA instructions. Otherwise, ignored.
|
||||
// - Upper 32b URh: Specifies the instruction descriptor.
|
||||
uint64_t idescE = (static_cast<uint64_t>(static_cast<uint32_t>(desc_i)) << 32);
|
||||
|
||||
return idescE;
|
||||
}
|
||||
|
||||
template <bool is_sparse = false>
|
||||
CUTE_HOST_DEVICE
|
||||
constexpr uint64_t
|
||||
make_runtime_instr_desc_block_scaled(UMMA::InstrDescriptorBlockScaled desc_i,
|
||||
uint32_t const tmem_sfa_addr, uint32_t const tmem_sfb_addr,
|
||||
uint16_t const sparse_id2 = 0u, uint32_t const tmem_e = 0u)
|
||||
{
|
||||
// The first 2-bits of TMEM address includes byte address.
|
||||
desc_i.a_sf_id_ = (tmem_sfa_addr & 0xC0000000) >> 30;
|
||||
desc_i.b_sf_id_ = (tmem_sfb_addr & 0xC0000000) >> 30;
|
||||
|
||||
if constexpr (is_sparse) {
|
||||
desc_i.sparse_id2_ = sparse_id2;
|
||||
}
|
||||
else {
|
||||
assert(sparse_id2 == 0u);
|
||||
}
|
||||
|
||||
// In current compiler exposure, idescE is a uint64_t. It should contain:
|
||||
// - Lower 32b URe: Specifies the tmem address that stores the sparse metadata.
|
||||
// Only needed for Sparse MMA instructions. Otherwise, ignored.
|
||||
// - Upper 32b URh: Specifies the instruction descriptor.
|
||||
uint64_t idescE = (static_cast<uint64_t>(static_cast<uint32_t>(desc_i)) << 32);
|
||||
|
||||
return idescE;
|
||||
}
|
||||
|
||||
} // end namespace UMMA
|
||||
} // namespace cute
|
||||
1074
include/cute/arch/mma_sm100_umma.hpp
Normal file
1074
include/cute/arch/mma_sm100_umma.hpp
Normal file
File diff suppressed because it is too large
Load Diff
96
include/cute/arch/simd_sm100.hpp
Normal file
96
include/cute/arch/simd_sm100.hpp
Normal file
@ -0,0 +1,96 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
//
|
||||
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
#include <cute/arch/config.hpp>
|
||||
#include <cute/numeric/real.hpp>
|
||||
|
||||
namespace cute {
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
add(float2 & c,
|
||||
float2 const& a,
|
||||
float2 const& b)
|
||||
{
|
||||
#if defined(CUTE_ARCH_FLOAT2_MATH_ENABLED)
|
||||
asm volatile("add.f32x2 %0, %1, %2;\n"
|
||||
: "=l"(reinterpret_cast<uint64_t &>(c))
|
||||
: "l"(reinterpret_cast<uint64_t const&>(a)),
|
||||
"l"(reinterpret_cast<uint64_t const&>(b)));
|
||||
#else
|
||||
add(c.x, a.x, b.x);
|
||||
add(c.y, a.y, b.y);
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
mul(float2 & c,
|
||||
float2 const& a,
|
||||
float2 const& b)
|
||||
{
|
||||
#if defined(CUTE_ARCH_FLOAT2_MATH_ENABLED)
|
||||
asm volatile("mul.f32x2 %0, %1, %2;\n"
|
||||
: "=l"(reinterpret_cast<uint64_t &>(c))
|
||||
: "l"(reinterpret_cast<uint64_t const&>(a)),
|
||||
"l"(reinterpret_cast<uint64_t const&>(b)));
|
||||
#else
|
||||
mul(c.x, a.x, b.x);
|
||||
mul(c.y, a.y, b.y);
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
fma(float2 & d,
|
||||
float2 const& a,
|
||||
float2 const& b,
|
||||
float2 const& c)
|
||||
{
|
||||
#if defined(CUTE_ARCH_FLOAT2_MATH_ENABLED)
|
||||
asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n"
|
||||
: "=l"(reinterpret_cast<uint64_t &>(d))
|
||||
: "l"(reinterpret_cast<uint64_t const&>(a)),
|
||||
"l"(reinterpret_cast<uint64_t const&>(b)),
|
||||
"l"(reinterpret_cast<uint64_t const&>(c)));
|
||||
#else
|
||||
fma(d.x, a.x, b.x, c.x);
|
||||
fma(d.y, a.y, b.y, c.y);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace cute
|
||||
168
include/cute/arch/tmem_allocator_sm100.hpp
Normal file
168
include/cute/arch/tmem_allocator_sm100.hpp
Normal file
@ -0,0 +1,168 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
//
|
||||
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include <cute/arch/config.hpp>
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/atom/copy_traits_sm100.hpp>
|
||||
|
||||
#include <cutlass/pipeline/sm90_pipeline.hpp>
|
||||
|
||||
namespace cute::TMEM {
|
||||
|
||||
// All operations of this class require that only a single warp uniformly participates
|
||||
class Allocator1Sm {
|
||||
public:
|
||||
static constexpr int ColumnsPerAllocationSlice = 32;
|
||||
static constexpr int Sm100TmemCapacityColumns = 512;
|
||||
|
||||
__device__ Allocator1Sm() { }
|
||||
|
||||
/**
|
||||
* Performs a non-blocking allocation of TMEM.
|
||||
* @param num_columns Number of columns being freed. Must be 32 <= num_columns <= 512 and power of 2.
|
||||
* @param dst_ptr Pointer to shared memory to which to write the result tmem pointer to.
|
||||
* @pre Must be issued by a single fully active warp of the CTA.
|
||||
* @pre Must never be issued by more than one warp at the same time.
|
||||
* @pre For repeated allocations, the same warp must be used to issue all allocations.
|
||||
**/
|
||||
__device__ void
|
||||
allocate(int num_columns, uint32_t* dst_ptr) {
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
uint32_t dst_intptr = cute::cast_smem_ptr_to_uint(dst_ptr);
|
||||
asm volatile(
|
||||
"tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;"
|
||||
:
|
||||
: "r"(dst_intptr), "r"(num_columns));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED");
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__
|
||||
void
|
||||
free(uint32_t tmem_ptr, int num_columns) {
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
"tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1; \n\t"
|
||||
"}"
|
||||
:
|
||||
: "r"(tmem_ptr), "r"(num_columns));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED");
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ void
|
||||
release_allocation_lock() {
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile("tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned;" ::);
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
class Allocator2Sm {
|
||||
public:
|
||||
static constexpr int ColumnsPerAllocationSlice = 32;
|
||||
static constexpr int Sm100TmemCapacityColumns = 512;
|
||||
|
||||
__device__ Allocator2Sm() { }
|
||||
|
||||
/**
|
||||
* Performs a non-blocking allocation of TMEM.
|
||||
* @param num_columns Number of columns being freed. Must be 32 <= num_columns <= 512 and power of 2.
|
||||
* @param dst_ptr Pointer to shared memory to which to write the result tmem pointer to.
|
||||
* Both CTAs _must_ provide the exact same dst_ptr for correctness.
|
||||
* @pre Must be issued by a single fully active warp of the CTA.
|
||||
* @pre Must never be issued by more than one warp at the same time.
|
||||
* @pre For repeated allocations, the same warp must be used to issue all allocations.
|
||||
* @pre The 2 warps from participating CTAs have the same logical warp ID.
|
||||
**/
|
||||
__device__ void
|
||||
allocate(int num_columns, uint32_t* dst_ptr) {
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
uint32_t dst_intptr = cute::cast_smem_ptr_to_uint(dst_ptr);
|
||||
asm volatile(
|
||||
"tcgen05.alloc.cta_group::2.sync.aligned.shared::cta.b32 [%0], %1;"
|
||||
:
|
||||
: "r"(dst_intptr), "r"(num_columns));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED");
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* Frees the TMEM corresponding to the pointer and slice count provided.
|
||||
* Release the TMEM after checking that the CTA issuing the free does indeed own the corresponding slices.
|
||||
* @param tmem_ptr Base address of the TMEM address space being freed.
|
||||
* @param num_columns Number of columns being freed. Must be 32 <= num_columns <= 512 and power of 2.
|
||||
* @pre Must be issued by a single fully active warp of the CTA.
|
||||
* @pre Must never be issued by more than one warp at the same time.
|
||||
* @pre The 2 warps from participating CTAs have the same logical warp ID.
|
||||
* @returns true
|
||||
**/
|
||||
__device__
|
||||
void
|
||||
free(uint32_t tmem_ptr, int num_columns) {
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
"tcgen05.dealloc.cta_group::2.sync.aligned.b32 %0, %1; \n\t"
|
||||
"}"
|
||||
:
|
||||
: "r"(tmem_ptr), "r"(num_columns));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED");
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__
|
||||
void
|
||||
release_allocation_lock() {
|
||||
#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
|
||||
asm volatile("tcgen05.relinquish_alloc_permit.cta_group::2.sync.aligned;" ::);
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cute::TMEM
|
||||
@ -751,14 +751,33 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and
|
||||
#include <cute/atom/copy_traits_sm75.hpp>
|
||||
#include <cute/atom/copy_traits_sm80.hpp>
|
||||
#include <cute/atom/copy_traits_sm90.hpp>
|
||||
#include <cute/atom/copy_traits_sm100.hpp>
|
||||
|
||||
|
||||
// Config
|
||||
#if (__CUDACC_VER_MAJOR__ >= 12)
|
||||
# define CUTE_COPY_ATOM_TMA_SM90_ENABLED
|
||||
# define CUTE_COPY_ATOM_TMA_SM100_ENABLED
|
||||
#endif
|
||||
|
||||
|
||||
#if (!defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED))
|
||||
# define CUTE_COPY_ATOM_TMA_SM90_ENABLED
|
||||
#endif
|
||||
|
||||
#if (!defined(CUTE_COPY_ATOM_TMA_SM100_ENABLED))
|
||||
# define CUTE_COPY_ATOM_TMA_SM100_ENABLED
|
||||
#endif
|
||||
|
||||
|
||||
#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)
|
||||
#include <cute/atom/copy_traits_sm90_tma.hpp>
|
||||
#endif
|
||||
|
||||
|
||||
#if defined(CUTE_COPY_ATOM_TMA_SM100_ENABLED)
|
||||
#include <cute/atom/copy_traits_sm100_tma.hpp>
|
||||
#endif
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
3797
include/cute/atom/copy_traits_sm100.hpp
Normal file
3797
include/cute/atom/copy_traits_sm100.hpp
Normal file
File diff suppressed because it is too large
Load Diff
488
include/cute/atom/copy_traits_sm100_im2col.hpp
Normal file
488
include/cute/atom/copy_traits_sm100_im2col.hpp
Normal file
@ -0,0 +1,488 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
/*! \file
|
||||
\brief im2col make_tma_copy
|
||||
|
||||
*/
|
||||
|
||||
#include "cute/arch/copy_sm90.hpp"
|
||||
#include "cute/arch/copy_sm90_desc.hpp"
|
||||
#include "cute/atom/copy_traits_sm90_im2col.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
namespace cute {
|
||||
|
||||
struct SM100_TMA_2SM_LOAD_IM2COL_OP : SM100_TMA_2SM_LOAD_IM2COL {};
|
||||
|
||||
/// @brief Non-executable specialization of Copy_Traits for SM100
|
||||
/// im2col TMA load, with TMA descriptor but no barrier.
|
||||
///
|
||||
/// Use `.with(memory_barrier)` to construct an executable version.
|
||||
template <class NumBitsPerTMA, class TMATensor>
|
||||
struct Copy_Traits<SM100_TMA_2SM_LOAD_IM2COL, NumBitsPerTMA, TMATensor>
|
||||
{
|
||||
using ThrID = Layout<_2>;
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_2, NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_2, NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
|
||||
Im2ColTmaDescriptor tma_desc_;
|
||||
TMATensor tma_tensor_;
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
Im2ColTmaDescriptor const*
|
||||
get_tma_descriptor() const
|
||||
{
|
||||
return &tma_desc_;
|
||||
}
|
||||
|
||||
template <class GShape>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
TMATensor const
|
||||
get_tma_tensor(GShape const&) const
|
||||
{
|
||||
return tma_tensor_;
|
||||
}
|
||||
|
||||
/// @brief Get an executable specialization.
|
||||
///
|
||||
/// Copy_Traits specializations with SM100_TMA_2SM_LOAD_IM2COL are not
|
||||
/// directly executable. Instead, call this "with" member function
|
||||
/// to get an executable specialization. "Executable" means that
|
||||
/// @c copy_unpack works.
|
||||
///
|
||||
/// @param tma_mbar Memory barrier for synchronization
|
||||
///
|
||||
/// @param multicast_mask Multicast mask (unused; only exists
|
||||
/// for consistency with the actual multicast Copy_Traits
|
||||
/// specialization)
|
||||
///
|
||||
/// @return Executable specialization of @c Copy_Traits
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
Copy_Traits<SM100_TMA_2SM_LOAD_IM2COL_OP, NumBitsPerTMA>
|
||||
with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const
|
||||
{
|
||||
return {{}, {&tma_desc_, &tma_mbar}};
|
||||
}
|
||||
|
||||
// Copy_Traits specializations with SM100_TMA_2SM_LOAD_IM2COL
|
||||
// are not directly executable. Instead, call .with
|
||||
// to get an executable specialization.
|
||||
template <class TS, class SLayout,
|
||||
class TD, class DLayout>
|
||||
CUTE_HOST_DEVICE friend constexpr void
|
||||
copy_unpack(Copy_Traits const& traits,
|
||||
Tensor<TS,SLayout> const& src,
|
||||
Tensor<TD,DLayout> & dst) = delete;
|
||||
};
|
||||
|
||||
/// TMA load, with TMA descriptor and barrier.
|
||||
template <class NumBitsPerTMA>
|
||||
struct Copy_Traits<SM100_TMA_2SM_LOAD_IM2COL_OP, NumBitsPerTMA>
|
||||
: TMA_LOAD_IM2COL_Unpack<SM100_TMA_2SM_LOAD_IM2COL_OP>
|
||||
{
|
||||
using ThrID = Layout<_2>;
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_2, NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_2, NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
|
||||
// SM100_TMA_2SM_LOAD_IM2COL arguments
|
||||
tuple<
|
||||
Im2ColTmaDescriptor const*,
|
||||
uint64_t* // smem mbarrier
|
||||
> const opargs_;
|
||||
};
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////// TMA_LOAD_MULTICAST /////////////////////////////
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_OP : SM100_TMA_2SM_LOAD_IM2COL_MULTICAST {};
|
||||
|
||||
/// @brief Non-executable specialization of Copy_Traits for SM100
|
||||
/// im2col TMA load, with TMA descriptor but no barrier or multicast
|
||||
/// mask.
|
||||
///
|
||||
/// Use `.with(memory_barrier)` to construct an executable version.
|
||||
template <class NumBitsPerTMA, class TMATensor>
|
||||
struct Copy_Traits<SM100_TMA_2SM_LOAD_IM2COL_MULTICAST, NumBitsPerTMA, TMATensor>
|
||||
{
|
||||
using ThrID = Layout<_2>;
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_2, NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_2, NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
|
||||
Im2ColTmaDescriptor tma_desc_;
|
||||
TMATensor tma_tensor_;
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
Im2ColTmaDescriptor const*
|
||||
get_tma_descriptor() const
|
||||
{
|
||||
return &tma_desc_;
|
||||
}
|
||||
|
||||
template <class GShape>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
TMATensor const
|
||||
get_tma_tensor(GShape const&) const
|
||||
{
|
||||
return tma_tensor_;
|
||||
}
|
||||
|
||||
/// @brief Get an executable specialization.
|
||||
///
|
||||
/// Copy_Traits specializations with SM100_TMA_2SM_LOAD_IM2COL_MULTICAST
|
||||
/// are not directly executable. Instead, call this "with" member
|
||||
/// function to get an executable specialization. "Executable"
|
||||
/// means that @c copy_unpack works.
|
||||
///
|
||||
/// @param tma_mbar Memory barrier for synchronization
|
||||
///
|
||||
/// @param multicast_mask Multicast mask (defaults to a single CTA)
|
||||
///
|
||||
/// @return Executable specialization of @c Copy_Traits
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
Copy_Traits<SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_OP, NumBitsPerTMA>
|
||||
with(uint64_t& tma_mbar, uint16_t const& multicast_mask) const
|
||||
{
|
||||
return {{}, {&tma_desc_, &tma_mbar, multicast_mask}};
|
||||
}
|
||||
|
||||
// Copy_Traits specializations with SM100_TMA_LOAD_IM2COL_MULTICAST
|
||||
// are not directly executable. Instead, call .with to get an
|
||||
// executable specialization.
|
||||
template <class TS, class SLayout,
|
||||
class TD, class DLayout>
|
||||
CUTE_HOST_DEVICE friend constexpr void
|
||||
copy_unpack(Copy_Traits const& traits,
|
||||
Tensor<TS,SLayout> const& src,
|
||||
Tensor<TD,DLayout> & dst) = delete;
|
||||
};
|
||||
|
||||
/// @brief Executable specialization of Copy_Traits for SM100 multicast
|
||||
/// im2col TMA load, with TMA descriptor, barrier, and multicast mask.
|
||||
template <class NumBitsPerTMA>
|
||||
struct Copy_Traits<SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_OP, NumBitsPerTMA>
|
||||
: TMA_LOAD_IM2COL_Unpack<SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_OP>
|
||||
{
|
||||
using ThrID = Layout<_2>;
|
||||
// Map from (src-thr,src-val) to bit.
|
||||
using SrcLayout = Layout<Shape<_2, NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_2, NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
|
||||
// SM100_TMA_2SM_LOAD_IM2COL_MULTICAST arguments
|
||||
tuple<
|
||||
Im2ColTmaDescriptor const*,
|
||||
uint64_t*, // smem mbarrier
|
||||
uint16_t // multicast mask
|
||||
> const opargs_;
|
||||
};
|
||||
|
||||
////////////////////////////////////
|
||||
// Make TMA
|
||||
///////////////////////////////////
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
/** Make a CuTe CTA-collective TiledCopy for a TMA operation.
|
||||
*
|
||||
* @param CopyOp The target copy operation: SM100_TMA_2SM_LOAD
|
||||
* @param gtensor The GMEM Tensor to be involved in the TMA.
|
||||
* @param slayout The SMEM Layout to be involved in the TMA.
|
||||
* @param cluster_tile The Cluster-local tile that each Cluster will be tiling GMEM with.
|
||||
* This is often the cluster_tile_shape that is used to tile the GMEM:
|
||||
* local_tile(gtensor, cluster_tile_shape, cluster_coord)
|
||||
* -> Cluster-local tile of GMEM
|
||||
* @param mma The TiledMMA that defines the Cluster-Tile to Block-Tile partitioning.
|
||||
*
|
||||
* This code attempts to maximize the TMA box size. It does this by tracing
|
||||
* the SMEM "vector" -- the inverse of the smem layout -- to find the largest
|
||||
* contiguous array of smem that can be written to/from global memory given
|
||||
* the constraints that the TMA instruction imposes.
|
||||
*
|
||||
* This is accomplished by assigning "basis" strides to the GMEM to track which
|
||||
* modes of SMEM map to which modes of GMEM, then reordering the modes of GMEM according
|
||||
* to the SMEM vector, and then using those GMEM/SMEM modes to fill in the desc.
|
||||
*
|
||||
* Examples:
|
||||
*/
|
||||
template <class CopyOp,
|
||||
class GEngine, class GLayout,
|
||||
class SLayout,
|
||||
class Cluster_Tile,
|
||||
class... Args,
|
||||
class LowerCornerStride,
|
||||
class UpperCornerStride,
|
||||
class LowerPaddingStride,
|
||||
class UpperPaddingStride,
|
||||
class TraversalStride,
|
||||
class LowerSRTStride,
|
||||
class DilationStride>
|
||||
CUTE_HOST
|
||||
auto
|
||||
make_im2col_tma_copy_A_sm100(CopyOp const& copy_op,
|
||||
Tensor<GEngine,GLayout> const& gtensor, // (M,K,...)
|
||||
SLayout const& slayout, // (MMA, MMA_M, MMA_K)
|
||||
Cluster_Tile const& cluster_tile, // (TILE_M,TILE_N,TILE_K)
|
||||
TiledMMA<Args...> const& mma,
|
||||
LowerCornerStride const& lower_corner_whd,
|
||||
UpperCornerStride const& upper_corner_whd,
|
||||
LowerPaddingStride const& lower_padding_whd,
|
||||
UpperPaddingStride const& upper_padding_whd,
|
||||
TraversalStride const& stride_whd,
|
||||
LowerSRTStride const& lower_srt,
|
||||
DilationStride const& stride_srt,
|
||||
TMA::DescriptorAuxParams const& aux_params = {})
|
||||
{
|
||||
constexpr int R = GLayout::rank;
|
||||
// Keep only MK modes from MNK
|
||||
auto cluster_tile_shape = append<R>(make_shape(get<0>(cluster_tile), get<2>(cluster_tile)), Int<1>{});
|
||||
auto cluster_layout = make_identity_layout(cluster_tile_shape);
|
||||
// cta val idx -> gmem mode
|
||||
auto cta_v_tile = layout<1>(mma.thrfrg_A(cluster_layout))(_, repeat<R>(_));
|
||||
|
||||
auto cta_t_vmnk_strides = [](){
|
||||
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD_IM2COL_MULTICAST> ||
|
||||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD_IM2COL_MULTICAST>) {
|
||||
return Stride<_0,_0,_1,_0>{}; // VMNK: Use only the N-CTAs in the Multicast
|
||||
} else
|
||||
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD_IM2COL> ||
|
||||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD_IM2COL>) {
|
||||
return Stride<_0,_0,_0,_0>{}; // VMNK: Use no CTAs in Non-Multicast
|
||||
} else {
|
||||
static_assert(dependent_false<CopyOp>, "Unsupported TMA");
|
||||
}
|
||||
}();
|
||||
|
||||
auto cta_t_shape = shape(mma.get_thr_layout_vmnk());
|
||||
// cta rank -> logical cta idx
|
||||
auto cta_t_map = make_layout(cta_t_shape, compact_col_major(cta_t_shape, cta_t_vmnk_strides));
|
||||
|
||||
return detail::make_tma_copy_im2col(copy_op, gtensor, slayout,
|
||||
cta_t_map, cta_v_tile,
|
||||
lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd,
|
||||
lower_srt, stride_srt, aux_params);
|
||||
}
|
||||
|
||||
template <class CopyOp,
|
||||
class GEngine, class GLayout,
|
||||
class SLayout,
|
||||
class Cluster_Tile,
|
||||
class... Args,
|
||||
class LowerCornerStride,
|
||||
class UpperCornerStride,
|
||||
class LowerPaddingStride,
|
||||
class UpperPaddingStride,
|
||||
class TraversalStride,
|
||||
class LowerSRTStride,
|
||||
class DilationStride>
|
||||
CUTE_HOST
|
||||
auto
|
||||
make_im2col_tma_copy_B_sm100(CopyOp const& copy_op,
|
||||
Tensor<GEngine,GLayout> const& gtensor, // (N,K,...)
|
||||
SLayout const& slayout, // (MMA, MMA_N, MMA_K)
|
||||
Cluster_Tile const& cluster_tile, // (TILE_M,TILE_N,TILE_K)
|
||||
TiledMMA<Args...> const& mma,
|
||||
LowerCornerStride const& lower_corner_whd,
|
||||
UpperCornerStride const& upper_corner_whd,
|
||||
LowerPaddingStride const& lower_padding_whd,
|
||||
UpperPaddingStride const& upper_padding_whd,
|
||||
TraversalStride const& stride_whd,
|
||||
LowerSRTStride const& lower_srt,
|
||||
DilationStride const& stride_srt,
|
||||
TMA::DescriptorAuxParams const& aux_params = {})
|
||||
{
|
||||
constexpr int R = GLayout::rank;
|
||||
// Keep only NK modes from MNK
|
||||
auto cluster_tile_shape = append<R>(make_shape(get<1>(cluster_tile), get<2>(cluster_tile)), Int<1>{});
|
||||
auto cluster_layout = make_identity_layout(cluster_tile_shape);
|
||||
// cta val idx -> gmem mode
|
||||
auto cta_v_tile = layout<1>(mma.thrfrg_B(cluster_layout))(_, repeat<R>(_));
|
||||
|
||||
auto cta_t_vmnk_strides = [](){
|
||||
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD_IM2COL_MULTICAST> ||
|
||||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD_IM2COL_MULTICAST>) {
|
||||
return Stride<_0,_1,_0,_0>{}; // VMNK: Use only the M-CTAs in the Multicast
|
||||
} else
|
||||
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD_IM2COL> ||
|
||||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD_IM2COL>) {
|
||||
return Stride<_0,_0,_0,_0>{}; // VMNK: Use no CTAs in Non-Multicast
|
||||
} else {
|
||||
static_assert(dependent_false<CopyOp>, "Unsupported TMA");
|
||||
}
|
||||
}();
|
||||
|
||||
auto cta_t_shape = shape(mma.get_thr_layout_vmnk());
|
||||
// cta rank -> logical cta idx
|
||||
auto cta_t_map = make_layout(cta_t_shape, compact_col_major(cta_t_shape, cta_t_vmnk_strides));
|
||||
|
||||
return detail::make_tma_copy_im2col(copy_op, gtensor, slayout,
|
||||
cta_t_map, cta_v_tile,
|
||||
lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd,
|
||||
lower_srt, stride_srt, aux_params);
|
||||
}
|
||||
|
||||
/////////////////////////////////////
|
||||
// Experimental Make Im2col TMA Atom
|
||||
/////////////////////////////////////
|
||||
|
||||
template <class TmaInternalType = void,
|
||||
class CopyOp,
|
||||
class GEngine, class GLayout,
|
||||
class SLayout,
|
||||
class MMA_Tiler,
|
||||
class... Args,
|
||||
class ClusterShapeVMNK,
|
||||
class LowerCornerStride,
|
||||
class UpperCornerStride,
|
||||
class LowerPaddingStride,
|
||||
class UpperPaddingStride,
|
||||
class TraversalStride,
|
||||
class LowerSRTStride,
|
||||
class DilationStride>
|
||||
CUTE_HOST
|
||||
auto
|
||||
make_im2col_tma_atom_A_sm100(CopyOp const& copy_op,
|
||||
Tensor<GEngine,GLayout> const& gtensor, // (M, K, ...)
|
||||
SLayout const& slayout, // (MMA, MMA_M, MMA_K, ...)
|
||||
MMA_Tiler const& mma_tiler, // (TILE_M, TILE_N, TILE_K, ...)
|
||||
TiledMMA<Args...> const& mma,
|
||||
ClusterShapeVMNK const& cluster_shape, // (CTA_V, CTA_M, CTA_N, CTA_K)
|
||||
LowerCornerStride const& lower_corner_whd,
|
||||
UpperCornerStride const& upper_corner_whd,
|
||||
LowerPaddingStride const& lower_padding_whd,
|
||||
UpperPaddingStride const& upper_padding_whd,
|
||||
TraversalStride const& stride_whd,
|
||||
LowerSRTStride const& lower_srt,
|
||||
DilationStride const& stride_srt,
|
||||
TMA::DescriptorAuxParams const& aux_params = {})
|
||||
{
|
||||
constexpr int R = GLayout::rank;
|
||||
// Keep only MK modes from MNK
|
||||
auto cluster_tile_shape = append<R>(make_shape(get<0>(mma_tiler), get<2>(mma_tiler)), Int<1>{});
|
||||
auto cluster_layout = make_identity_layout(cluster_tile_shape);
|
||||
// cta val idx -> gmem mode
|
||||
auto cta_v_tile = layout<1>(mma.thrfrg_A(cluster_layout))(_, repeat<R>(_));
|
||||
|
||||
// The size of the multicasting
|
||||
auto num_multicast = [&](){
|
||||
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD_IM2COL_MULTICAST> ||
|
||||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD_IM2COL_MULTICAST>) {
|
||||
return size<2>(cluster_shape); // VMNK: Use only the N-CTAs in the Multicast
|
||||
} else
|
||||
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD_IM2COL> ||
|
||||
is_same_v<CopyOp, SM90_TMA_STORE_IM2COL> ||
|
||||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD_IM2COL>) {
|
||||
return Int<1>{}; // VMNK: Use no CTAs in Non-Multicast
|
||||
} else {
|
||||
static_assert(dependent_false<CopyOp>, "Unsupported TMA");
|
||||
}
|
||||
}();
|
||||
|
||||
return detail::make_tma_atom_im2col(copy_op, gtensor, slayout, num_multicast, cta_v_tile,
|
||||
lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd,
|
||||
stride_whd, lower_srt, stride_srt, aux_params);
|
||||
}
|
||||
|
||||
template <class TmaInternalType = void,
|
||||
class CopyOp,
|
||||
class GEngine, class GLayout,
|
||||
class SLayout,
|
||||
class MMA_Tiler,
|
||||
class... Args,
|
||||
class ClusterShapeVMNK,
|
||||
class LowerCornerStride,
|
||||
class UpperCornerStride,
|
||||
class LowerPaddingStride,
|
||||
class UpperPaddingStride,
|
||||
class TraversalStride,
|
||||
class LowerSRTStride,
|
||||
class DilationStride>
|
||||
CUTE_HOST
|
||||
auto
|
||||
make_im2col_tma_atom_B_sm100(CopyOp const& copy_op,
|
||||
Tensor<GEngine,GLayout> const& gtensor, // (N, K, ...)
|
||||
SLayout const& slayout, // (MMA, MMA_N, MMA_K, ...)
|
||||
MMA_Tiler const& mma_tiler, // (TILE_M, TILE_N, TILE_K, ...)
|
||||
TiledMMA<Args...> const& mma,
|
||||
ClusterShapeVMNK const& cluster_shape, // (CTA_V, CTA_M, CTA_N, CTA_K)
|
||||
LowerCornerStride const& lower_corner_whd,
|
||||
UpperCornerStride const& upper_corner_whd,
|
||||
LowerPaddingStride const& lower_padding_whd,
|
||||
UpperPaddingStride const& upper_padding_whd,
|
||||
TraversalStride const& stride_whd,
|
||||
LowerSRTStride const& lower_srt,
|
||||
DilationStride const& stride_srt,
|
||||
TMA::DescriptorAuxParams const& aux_params = {})
|
||||
{
|
||||
constexpr int R = GLayout::rank;
|
||||
// Keep only NK modes from MNK
|
||||
auto cluster_tile_shape = append<R>(make_shape(get<1>(mma_tiler), get<2>(mma_tiler)), Int<1>{});
|
||||
auto cluster_layout = make_identity_layout(cluster_tile_shape);
|
||||
// cta val idx -> gmem mode
|
||||
auto cta_v_tile = layout<1>(mma.thrfrg_B(cluster_layout))(_, repeat<R>(_));
|
||||
|
||||
// The size of the multicasting
|
||||
auto num_multicast = [&](){
|
||||
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD_IM2COL_MULTICAST> ||
|
||||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD_IM2COL_MULTICAST>) {
|
||||
return size<1>(cluster_shape); // VMNK: Use only the M-CTAs in the Multicast
|
||||
} else
|
||||
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD_IM2COL> ||
|
||||
is_same_v<CopyOp, SM90_TMA_STORE_IM2COL> ||
|
||||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD_IM2COL>) {
|
||||
return Int<1>{}; // VMNK: Use no CTAs in Non-Multicast
|
||||
} else {
|
||||
static_assert(dependent_false<CopyOp>, "Unsupported TMA");
|
||||
}
|
||||
}();
|
||||
|
||||
return detail::make_tma_atom_im2col(copy_op, gtensor, slayout, num_multicast, cta_v_tile,
|
||||
lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd,
|
||||
stride_whd, lower_srt, stride_srt, aux_params);
|
||||
}
|
||||
#endif // !defined(__CUDACC_RTC__)
|
||||
|
||||
} // end namespace cute
|
||||
487
include/cute/atom/copy_traits_sm100_tma.hpp
Normal file
487
include/cute/atom/copy_traits_sm100_tma.hpp
Normal file
@ -0,0 +1,487 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2021 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include <cuda.h>
|
||||
#endif
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cute/atom/copy_traits_sm90_tma.hpp>
|
||||
#include <cute/arch/copy_sm100_tma.hpp>
|
||||
#include <cute/atom/copy_traits.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////// TMA_LOAD ////////////////////////////////////////
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SM100_TMA_2SM_LOAD_OP : SM100_TMA_2SM_LOAD {};
|
||||
|
||||
// The non-executable SM100_TMA_2SM_LOAD with tma_desc and no tma_mbar
|
||||
// Use .with(tma_mbar) to construct an executable version
|
||||
template <class NumBitsPerTMA, class AuxParams_>
|
||||
struct Copy_Traits<SM100_TMA_2SM_LOAD, NumBitsPerTMA, AuxParams_>
|
||||
{
|
||||
using ThrID = Layout<_2>;
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_2,NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_2,NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
|
||||
// SM100_TMA_2SM_LOAD arguments
|
||||
TmaDescriptor tma_desc_;
|
||||
using AuxParams = AuxParams_;
|
||||
AuxParams aux_params_;
|
||||
|
||||
// Return TmaDescriptor/TensorMap
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
TmaDescriptor const*
|
||||
get_tma_descriptor() const {
|
||||
return &tma_desc_;
|
||||
}
|
||||
|
||||
// Construct an executable SM100_TMA_2SM_LOAD with tma_mbar
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
Copy_Traits<SM100_TMA_2SM_LOAD_OP, NumBitsPerTMA>
|
||||
with(
|
||||
uint64_t& tma_mbar,
|
||||
[[maybe_unused]] uint16_t const& multicast_mask = 0,
|
||||
TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const {
|
||||
// We accept multicast_mask here to keep the API for both atoms consistent
|
||||
return {{}, {&tma_desc_, &tma_mbar, static_cast<uint64_t>(cache_hint)}};
|
||||
}
|
||||
|
||||
// Construct an executable SM100_TMA_2SM_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
Copy_Traits<SM100_TMA_2SM_LOAD_OP, NumBitsPerTMA>
|
||||
with(
|
||||
TmaDescriptor const* new_tma_desc,
|
||||
uint64_t& tma_mbar,
|
||||
[[maybe_unused]] uint16_t const& multicast_mask = 0,
|
||||
TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const {
|
||||
// We accept multicast_mask here to keep the API for both atoms consistent
|
||||
return {{}, {new_tma_desc, &tma_mbar, static_cast<uint64_t>(cache_hint)}};
|
||||
}
|
||||
|
||||
template <class GShape>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
get_tma_tensor(GShape const& g_shape) const {
|
||||
static_assert(is_congruent<decltype(g_shape), decltype(aux_params_.g_stride_)>::value);
|
||||
return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_));
|
||||
}
|
||||
|
||||
// Don't try to execute a copy with SM100_TMA_2SM_LOAD before calling .with()
|
||||
template <class TS, class SLayout,
|
||||
class TD, class DLayout>
|
||||
CUTE_HOST_DEVICE friend constexpr void
|
||||
copy_unpack(Copy_Traits const& traits,
|
||||
Tensor<TS,SLayout> const& src,
|
||||
Tensor<TD,DLayout> & dst) = delete;
|
||||
};
|
||||
|
||||
// The executable SM100_TMA_2SM_LOAD with tma_desc and tma_mbar
|
||||
template <class NumBitsPerTMA>
|
||||
struct Copy_Traits<SM100_TMA_2SM_LOAD_OP, NumBitsPerTMA>
|
||||
: TMA_LOAD_Unpack<SM100_TMA_2SM_LOAD_OP, NumBitsPerTMA>
|
||||
{
|
||||
using ThrID = Layout<_2>;
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_2,NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_2,NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
|
||||
// SM100_TMA_2SM_LOAD arguments
|
||||
tuple<
|
||||
TmaDescriptor const*,
|
||||
uint64_t*, // smem mbarrier
|
||||
uint64_t // cache hint
|
||||
> const opargs_;
|
||||
};
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////// TMA_LOAD_MULTICAST /////////////////////////////
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SM100_TMA_2SM_LOAD_MULTICAST_OP : SM100_TMA_2SM_LOAD_MULTICAST {};
|
||||
|
||||
template <class NumBitsPerTMA, class AuxParams_>
|
||||
struct Copy_Traits<SM100_TMA_2SM_LOAD_MULTICAST, NumBitsPerTMA, AuxParams_>
|
||||
{
|
||||
using ThrID = Layout<_2>;
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_2,NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_2,NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
|
||||
// SM100_TMA_2SM_LOAD_MULTICAST_OP arguments
|
||||
TmaDescriptor tma_desc_;
|
||||
using AuxParams = AuxParams_;
|
||||
AuxParams aux_params_;
|
||||
|
||||
// Return TmaDescriptor/TensorMap
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
TmaDescriptor const*
|
||||
get_tma_descriptor() const {
|
||||
return &tma_desc_;
|
||||
}
|
||||
|
||||
// Construct an executable SM100_TMA_2SM_LOAD_MULTICAST_OP with tma_mbar
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
Copy_Traits<SM100_TMA_2SM_LOAD_MULTICAST_OP, NumBitsPerTMA>
|
||||
with(
|
||||
uint64_t& tma_load_mbar,
|
||||
uint16_t const& multicast_mask,
|
||||
TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const {
|
||||
return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask, static_cast<uint64_t>(cache_hint)}};
|
||||
}
|
||||
|
||||
// Construct an executable SM100_TMA_2SM_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
Copy_Traits<SM100_TMA_2SM_LOAD_MULTICAST_OP, NumBitsPerTMA>
|
||||
with(
|
||||
TmaDescriptor const* new_tma_desc,
|
||||
uint64_t& tma_load_mbar,
|
||||
uint16_t const& multicast_mask,
|
||||
TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const {
|
||||
return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask, static_cast<uint64_t>(cache_hint)}};
|
||||
}
|
||||
|
||||
template <class GShape>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
get_tma_tensor(GShape const& g_shape) const {
|
||||
static_assert(is_congruent<decltype(g_shape), decltype(aux_params_.g_stride_)>::value);
|
||||
return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_));
|
||||
}
|
||||
|
||||
// Don't try to execute a copy with SM100_TMA_2SM_LOAD_MULTICAST_OP before calling .with()
|
||||
template <class TS, class SLayout,
|
||||
class TD, class DLayout>
|
||||
CUTE_HOST_DEVICE friend constexpr void
|
||||
copy_unpack(Copy_Traits const& traits,
|
||||
Tensor<TS,SLayout> const& src,
|
||||
Tensor<TD,DLayout> & dst) = delete;
|
||||
};
|
||||
|
||||
template <class NumBitsPerTMA>
|
||||
struct Copy_Traits<SM100_TMA_2SM_LOAD_MULTICAST_OP, NumBitsPerTMA>
|
||||
: TMA_LOAD_Unpack<SM100_TMA_2SM_LOAD_MULTICAST_OP, NumBitsPerTMA>
|
||||
{
|
||||
using ThrID = Layout<_2>;
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_2,NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_2,NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
|
||||
// SM100_TMA_2SM_LOAD_MULTICAST_OP arguments
|
||||
tuple<
|
||||
TmaDescriptor const*,
|
||||
uint64_t*, // smem mbarrier
|
||||
uint16_t, // multicast mask
|
||||
uint64_t // cache hint
|
||||
> const opargs_;
|
||||
};
|
||||
|
||||
////////////////////////////////////
|
||||
// Make TMA
|
||||
///////////////////////////////////
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
/** Make a CuTe CTA-collective TiledCopy for a TMA operation.
|
||||
*
|
||||
* @param CopyOp The target copy operation: SM100_TMA_2SM_LOAD
|
||||
* @param gtensor The GMEM Tensor to be involved in the TMA.
|
||||
* @param slayout The SMEM Layout to be involved in the TMA.
|
||||
* @param cluster_tile The Cluster-local tile that each Cluster will be tiling GMEM with.
|
||||
* This is often the cluster_tile_shape that is used to tile the GMEM:
|
||||
* local_tile(gtensor, cluster_tile_shape, cluster_coord)
|
||||
* -> Cluster-local tile of GMEM
|
||||
* @param mma The TiledMMA that defines the Cluster-Tile to Block-Tile partitioning.
|
||||
*
|
||||
* This code attempts to maximize the TMA box size. It does this by tracing
|
||||
* the SMEM "vector" -- the inverse of the smem layout -- to find the largest
|
||||
* contiguous array of smem that can be written to/from global memory given
|
||||
* the constraints that the TMA instruction imposes.
|
||||
*
|
||||
* This is accomplished by assigning "basis" strides to the GMEM to track which
|
||||
* modes of SMEM map to which modes of GMEM, then reordering the modes of GMEM according
|
||||
* to the SMEM vector, and then using those GMEM/SMEM modes to fill in the desc.
|
||||
*
|
||||
* Examples:
|
||||
*/
|
||||
template <class TmaInternalType = void,
|
||||
class CopyOp,
|
||||
class GEngine, class GLayout,
|
||||
class SLayout,
|
||||
class Cluster_Tiler,
|
||||
class... Args>
|
||||
CUTE_HOST
|
||||
auto
|
||||
make_tma_copy_A_sm100(CopyOp const& copy_op,
|
||||
Tensor<GEngine,GLayout> const& gtensor, // (M, K, ...)
|
||||
SLayout const& slayout, // (MMA, MMA_M, MMA_K, ...)
|
||||
Cluster_Tiler const& cluster_tiler, // (TILER_M, TILER_N, TILER_K, ...)
|
||||
TiledMMA<Args...> const& mma)
|
||||
{
|
||||
// Keep only MK modes from MNK
|
||||
auto cluster_tiler_mk = remove<1>(cluster_tiler);
|
||||
// cluster tile coord -> gtensor coord
|
||||
auto g_tile = make_identity_layout(shape(gtensor)).compose(cluster_tiler_mk); // (TILE_M, TILE_K, ...)
|
||||
// cta val idx -> gmem mode
|
||||
auto cta_v_tile = layout<1>(mma.thrfrg_A(g_tile))(_, repeat<rank(g_tile)>(_)); // (MMA, MMA_M, MMA_K, ...)
|
||||
|
||||
auto cta_t_vmnk_strides = [](){
|
||||
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD_MULTICAST> ||
|
||||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD_MULTICAST>) {
|
||||
return Stride<_0,_0,_1,_0>{}; // VMNK: Use only the N-CTAs in the Multicast
|
||||
} else
|
||||
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD> ||
|
||||
is_same_v<CopyOp, SM90_TMA_STORE> ||
|
||||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD>) {
|
||||
return Stride<_0,_0,_0,_0>{}; // VMNK: Use no CTAs in Non-Multicast
|
||||
} else {
|
||||
static_assert(dependent_false<CopyOp>, "Unsupported TMA");
|
||||
}
|
||||
}();
|
||||
|
||||
auto cta_t_shape = shape(mma.get_thr_layout_vmnk());
|
||||
// cta rank -> logical cta idx
|
||||
auto cta_t_map = coalesce(make_layout(cta_t_shape, compact_col_major(cta_t_shape, cta_t_vmnk_strides)));
|
||||
|
||||
// Prefer TmaInternalType if specified. Fallback to GEngine::value_type
|
||||
using TmaType = conditional_t<is_same<void, TmaInternalType>::value, typename GEngine::value_type, TmaInternalType>;
|
||||
return detail::make_tma_copy_tiled<TmaType>(copy_op, gtensor, slayout, cta_t_map, cta_v_tile);
|
||||
}
|
||||
|
||||
template <class TmaInternalType = void,
|
||||
class CopyOp,
|
||||
class GEngine, class GLayout,
|
||||
class SLayout,
|
||||
class Cluster_Tiler,
|
||||
class... Args>
|
||||
CUTE_HOST
|
||||
auto
|
||||
make_tma_copy_B_sm100(CopyOp const& copy_op,
|
||||
Tensor<GEngine,GLayout> const& gtensor, // (N, K, ...)
|
||||
SLayout const& slayout, // (MMA, MMA_N, MMA_K, ...)
|
||||
Cluster_Tiler const& cluster_tiler, // (TILE_M, TILE_N, TILE_K, ...)
|
||||
TiledMMA<Args...> const& mma)
|
||||
{
|
||||
// Keep only NK modes from MNK
|
||||
auto cluster_tiler_nk = remove<0>(cluster_tiler);
|
||||
// cluster tile coord -> gtensor coord
|
||||
auto g_tile = make_identity_layout(shape(gtensor)).compose(cluster_tiler_nk); // (TILE_N, TILE_K, ...)
|
||||
// cta val idx -> gmem mode
|
||||
auto cta_v_tile = layout<1>(mma.thrfrg_B(g_tile))(_, repeat<rank(g_tile)>(_)); // (MMA, MMA_N, MMA_K, ...)
|
||||
|
||||
auto cta_t_vmnk_strides = [](){
|
||||
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD_MULTICAST> ||
|
||||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD_MULTICAST>) {
|
||||
return Stride<_0,_1,_0,_0>{}; // VMNK: Use only the M-CTAs in the Multicast
|
||||
} else
|
||||
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD> ||
|
||||
is_same_v<CopyOp, SM90_TMA_STORE> ||
|
||||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD>) {
|
||||
return Stride<_0,_0,_0,_0>{}; // VMNK: Use no CTAs in Non-Multicast
|
||||
} else {
|
||||
static_assert(dependent_false<CopyOp>, "Unsupported TMA");
|
||||
}
|
||||
}();
|
||||
|
||||
auto cta_t_shape = shape(mma.get_thr_layout_vmnk());
|
||||
// cta rank -> logical cta idx
|
||||
auto cta_t_map = coalesce(make_layout(cta_t_shape, compact_col_major(cta_t_shape, cta_t_vmnk_strides)));
|
||||
|
||||
// Prefer TmaInternalType if specified. Fallback to GEngine::value_type
|
||||
using TmaType = conditional_t<is_same<void, TmaInternalType>::value, typename GEngine::value_type, TmaInternalType>;
|
||||
return detail::make_tma_copy_tiled<TmaType>(copy_op, gtensor, slayout, cta_t_map, cta_v_tile);
|
||||
}
|
||||
|
||||
template <class TmaInternalType = void,
|
||||
class CopyOp,
|
||||
class GEngine, class GLayout,
|
||||
class SLayout,
|
||||
class Cluster_Tiler,
|
||||
class... Args>
|
||||
CUTE_HOST
|
||||
auto
|
||||
make_tma_copy_C_sm100(CopyOp const& copy_op,
|
||||
Tensor<GEngine,GLayout> const& gtensor, // (M, N, ...)
|
||||
SLayout const& slayout, // (MMA, MMA_M, MMA_N, ...)
|
||||
Cluster_Tiler const& cluster_tiler, // (TILE_M, TILE_N, TILE_K, ...)
|
||||
TiledMMA<Args...> const& mma)
|
||||
{
|
||||
// Keep only MN modes from MNK
|
||||
auto cluster_tiler_mn = remove<2>(cluster_tiler);
|
||||
// cluster tile coord -> gtensor coord
|
||||
auto g_tile = make_identity_layout(shape(gtensor)).compose(cluster_tiler_mn); // (TILE_M, TILE_N, ...)
|
||||
// cta val idx -> gmem mode
|
||||
auto cta_v_tile = layout<1>(mma.thrfrg_C(g_tile))(_, repeat<rank(g_tile)>(_)); // (MMA, MMA_M, MMA_N, ...)
|
||||
|
||||
static_assert(is_same_v<CopyOp, SM90_TMA_LOAD> ||
|
||||
is_same_v<CopyOp, SM90_TMA_STORE> ||
|
||||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD>,
|
||||
"Unsupported TMA Op, expected a non-multicast TMA");
|
||||
|
||||
// No multicast, so only 1 CTA involved
|
||||
auto cta_t_map = Layout<_1,_0>{};
|
||||
|
||||
// Prefer TmaInternalType if specified. Fallback to GEngine::value_type
|
||||
using TmaType = conditional_t<is_same<void, TmaInternalType>::value, typename GEngine::value_type, TmaInternalType>;
|
||||
return detail::make_tma_copy_tiled<TmaType>(copy_op, gtensor, slayout, cta_t_map, cta_v_tile);
|
||||
}
|
||||
|
||||
////////////////////////////////////
|
||||
// Experimental Make TMA Atom
|
||||
///////////////////////////////////
|
||||
|
||||
template <class TmaInternalType = void,
|
||||
class CopyOp,
|
||||
class GEngine, class GLayout,
|
||||
class SLayout,
|
||||
class MMA_Tiler,
|
||||
class... Args,
|
||||
class ClusterShapeVMNK>
|
||||
CUTE_HOST
|
||||
auto
|
||||
make_tma_atom_A_sm100(CopyOp const& copy_op,
|
||||
Tensor<GEngine,GLayout> const& gtensor, // (M, K, ...)
|
||||
SLayout const& slayout, // (MMA, MMA_M, MMA_K, ...)
|
||||
MMA_Tiler const& mma_tiler, // (TILE_M, TILE_N, TILE_K, ...)
|
||||
TiledMMA<Args...> const& mma,
|
||||
ClusterShapeVMNK const& cluster_shape) // (CTA_V, CTA_M, CTA_N, CTA_K)
|
||||
{
|
||||
// Keep only MK modes from MNK
|
||||
auto mma_tiler_mk = remove<1>(mma_tiler);
|
||||
|
||||
// cluster tile coord -> gtensor coord
|
||||
auto g_tile = make_identity_layout(shape(gtensor)).compose(mma_tiler_mk); // (TILE_M, TILE_K, ...)
|
||||
|
||||
// cta val idx -> gmem mode
|
||||
auto cta_v_tile = layout<1>(mma.thrfrg_A(g_tile))(_, repeat<rank(g_tile)>(_)); // (MMA, MMA_M, MMA_K, ...)
|
||||
|
||||
#if 0
|
||||
print("(tma_a) slayout: "); print(slayout); print("\n");
|
||||
print("(tma_a) mma_tiler_nk: "); print(mma_tiler_nk); print("\n");
|
||||
print("(tma_a) g_tile: "); print(g_tile); print("\n");
|
||||
print("(tma_a) mma_tiler: "); print(mma_tiler); print("\n");
|
||||
print("(tma_a) cta_v_tile: "); print(cta_v_tile); print("\n");
|
||||
#endif
|
||||
|
||||
// The size of the multicasting
|
||||
auto num_multicast = [&](){
|
||||
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD_MULTICAST> ||
|
||||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD_MULTICAST>) {
|
||||
return size<2>(cluster_shape); // VMNK: Use only the N-CTAs in the Multicast
|
||||
} else
|
||||
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD> ||
|
||||
is_same_v<CopyOp, SM90_TMA_STORE> ||
|
||||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD>) {
|
||||
return Int<1>{}; // VMNK: Use no CTAs in Non-Multicast
|
||||
} else {
|
||||
static_assert(dependent_false<CopyOp>, "Unsupported TMA");
|
||||
}
|
||||
}();
|
||||
|
||||
// Prefer TmaInternalType if specified. Fallback to GEngine::value_type
|
||||
using TmaType = conditional_t<is_same<void, TmaInternalType>::value, typename GEngine::value_type, TmaInternalType>;
|
||||
return detail::make_tma_copy_atom<TmaType>(copy_op, gtensor, slayout, num_multicast, cta_v_tile);
|
||||
}
|
||||
|
||||
template <class TmaInternalType = void,
|
||||
class CopyOp,
|
||||
class GEngine, class GLayout,
|
||||
class SLayout,
|
||||
class MMA_Tiler,
|
||||
class... Args,
|
||||
class ClusterShapeVMNK>
|
||||
CUTE_HOST
|
||||
auto
|
||||
make_tma_atom_B_sm100(CopyOp const& copy_op,
|
||||
Tensor<GEngine,GLayout> const& gtensor, // (N, K, ...)
|
||||
SLayout const& slayout, // (MMA, MMA_N, MMA_K, ...)
|
||||
MMA_Tiler const& mma_tiler, // (TILE_M, TILE_N, TILE_K, ...)
|
||||
TiledMMA<Args...> const& mma,
|
||||
ClusterShapeVMNK const& cluster_shape) // (CTA_V, CTA_M, CTA_N, CTA_K)
|
||||
{
|
||||
// Keep only NK modes from MNK
|
||||
auto mma_tiler_nk = remove<0>(mma_tiler);
|
||||
// cluster tile coord -> gtensor coord
|
||||
auto g_tile = make_identity_layout(shape(gtensor)).compose(mma_tiler_nk); // (TILE_N, TILE_K, ...)
|
||||
// cta val idx -> gmem mode
|
||||
auto cta_v_tile = layout<1>(mma.thrfrg_B(g_tile))(_, repeat<rank(g_tile)>(_)); // (MMA, MMA_N, MMA_K, ...)
|
||||
|
||||
#if 0
|
||||
print("(tma_b) slayout: "); print(slayout); print("\n");
|
||||
print("(tma_b) mma_tiler_nk: "); print(mma_tiler_nk); print("\n");
|
||||
print("(tma_b) g_tile: "); print(g_tile); print("\n");
|
||||
print("(tma_b) mma_tiler: "); print(mma_tiler); print("\n");
|
||||
print("(tma_b) cta_v_tile: "); print(cta_v_tile); print("\n");
|
||||
#endif
|
||||
|
||||
// The size of the multicasting
|
||||
auto num_multicast = [&](){
|
||||
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD_MULTICAST> ||
|
||||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD_MULTICAST>) {
|
||||
return size<1>(cluster_shape); // VMNK: Use only the M-CTAs in the Multicast
|
||||
} else
|
||||
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD> ||
|
||||
is_same_v<CopyOp, SM90_TMA_STORE> ||
|
||||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD>) {
|
||||
return Int<1>{}; // VMNK: Use no CTAs in Non-Multicast
|
||||
} else {
|
||||
static_assert(dependent_false<CopyOp>, "Unsupported TMA");
|
||||
}
|
||||
}();
|
||||
|
||||
// Prefer TmaInternalType if specified. Fallback to GEngine::value_type
|
||||
using TmaType = conditional_t<is_same<void, TmaInternalType>::value, typename GEngine::value_type, TmaInternalType>;
|
||||
return detail::make_tma_copy_atom<TmaType>(copy_op, gtensor, slayout, num_multicast, cta_v_tile);
|
||||
}
|
||||
|
||||
#endif // !defined(__CUDACC_RTC__)
|
||||
|
||||
} // end namespace cute
|
||||
@ -56,6 +56,13 @@ get_tma_swizzle_bits(Swizzle<B,M,S>)
|
||||
case 0: return TMA::SmemSwizzleBits::DISABLE;
|
||||
}
|
||||
} else
|
||||
|
||||
if constexpr (M == 5 || M == 6) {
|
||||
static_assert(B == 2, "Expected B = 2 when M == 5 or 6. Unsupported layout swizzle.");
|
||||
// S-condition as well?
|
||||
return TMA::SmemSwizzleBits::B128;
|
||||
} else
|
||||
|
||||
{
|
||||
static_assert(M < 0, "Unsupported layout swizzle.");
|
||||
}
|
||||
@ -78,9 +85,25 @@ get_tma_swizzle_base(Swizzle<B,M,S>)
|
||||
static_assert(S == 3, "Expected S = 3 when M == 4. Unsupported layout swizzle.");
|
||||
return TMA::SmemSwizzleBase::SWIZZLE_BASE_16B;
|
||||
}
|
||||
|
||||
else if constexpr (M == 5) {
|
||||
static_assert(B == 2, "Expected B = 2 when M == 5. Unsupported layout swizzle.");
|
||||
static_assert(S == 2, "Expected S = 2 when M == 5. Unsupported layout swizzle.");
|
||||
return TMA::SmemSwizzleBase::SWIZZLE_BASE_32B;
|
||||
} else if constexpr (M == 6) {
|
||||
static_assert(B == 2, "Expected B = 2 when M == 5. Unsupported layout swizzle.");
|
||||
return TMA::SmemSwizzleBase::SWIZZLE_BASE_64B;
|
||||
}
|
||||
#if 1
|
||||
else {
|
||||
static_assert(4 <= M && M <= 6, "Expected 128b=16B=(2^4)B to 512b=64B=(2^6)B base swizzle.");
|
||||
}
|
||||
#else
|
||||
|
||||
else {
|
||||
static_assert(M == 4, "Expected 128b=16B=(2^4)B base swizzle.");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class Layout>
|
||||
|
||||
@ -154,6 +154,10 @@ struct MMA_Atom<MMA_Traits<MMAOperation, Args...>>
|
||||
if constexpr (has_dereference<FrgTypeA>::value) {
|
||||
// If the intended FrgTypeA is a view (of the current tensor), forward the whole
|
||||
static_assert(is_same<ValTypeA, typename remove_cvref_t<ATensor>::value_type>::value
|
||||
|
||||
|| (sizeof_bits_v<typename remove_cvref_t<ATensor>::value_type> == 8 &&
|
||||
(sizeof_bits_v<ValTypeA> == 8 || sizeof_bits_v<ValTypeA> == 6 || sizeof_bits_v<ValTypeA> == 4))
|
||||
|
||||
, "Expecting ValTypeA type");
|
||||
return make_tensor<FrgTypeA>(static_cast<ATensor&&>(atensor));
|
||||
} else {
|
||||
@ -176,6 +180,10 @@ struct MMA_Atom<MMA_Traits<MMAOperation, Args...>>
|
||||
if constexpr (has_dereference<FrgTypeB>::value) {
|
||||
// If the intended FrgTypeB is a view (of the current tensor), forward the whole
|
||||
static_assert(is_same<ValTypeB, typename remove_cvref_t<BTensor>::value_type>::value
|
||||
|
||||
|| (sizeof_bits_v<typename remove_cvref_t<BTensor>::value_type> == 8 &&
|
||||
(sizeof_bits_v<ValTypeB> == 8 || sizeof_bits_v<ValTypeB> == 6 || sizeof_bits_v<ValTypeB> == 4))
|
||||
|
||||
, "Expecting ValTypeB type");
|
||||
return make_tensor<FrgTypeB>(static_cast<BTensor&&>(btensor));
|
||||
} else {
|
||||
@ -1109,4 +1117,5 @@ print_svg(TiledMMA<Args...> const &mma) {
|
||||
#include <cute/atom/mma_traits_sm80.hpp>
|
||||
#include <cute/atom/mma_traits_sm90.hpp>
|
||||
#include <cute/atom/mma_traits_sm90_gmma.hpp>
|
||||
#include <cute/atom/mma_traits_sm100.hpp>
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
2425
include/cute/atom/mma_traits_sm100.hpp
Normal file
2425
include/cute/atom/mma_traits_sm100.hpp
Normal file
File diff suppressed because it is too large
Load Diff
109
include/cute/atom/partitioner.hpp
Normal file
109
include/cute/atom/partitioner.hpp
Normal file
@ -0,0 +1,109 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#if defined(__CUDACC_RTC__)
|
||||
#include <cuda/std/type_traits>
|
||||
#else
|
||||
#include <type_traits>
|
||||
#endif
|
||||
|
||||
#include <cute/config.hpp>
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
namespace cute {
|
||||
|
||||
//
|
||||
// A generic tiling of thread-value layouts
|
||||
//
|
||||
|
||||
template <class Layout_TV_, // (tid,vid) -> coord [Need not be 2D...]
|
||||
class Tiler_MN_> // coord space
|
||||
struct TV_Tiler
|
||||
{
|
||||
using Tiler_MN = Tiler_MN_;
|
||||
using TiledLayout_TV = Layout_TV_;
|
||||
|
||||
// Tile a tensor or a layout from shape
|
||||
// (M,N,...)
|
||||
// to shape
|
||||
// ((ThrV,FrgV),(RestM,RestN,...))
|
||||
// where
|
||||
// ThrV: The threads local to a tile.
|
||||
// FrgV: The values local to a tile.
|
||||
// RestM: The values tiled in M.
|
||||
// RestN: The values tiled in N.
|
||||
template <class Tensor>
|
||||
CUTE_HOST_DEVICE constexpr static
|
||||
auto
|
||||
apply(Tensor&& tensor)
|
||||
{
|
||||
// If Layout_TV and Tiler_MN were composable in general, then this won't be needed!
|
||||
|
||||
// ((thr_id,val_id),(RestM,RestN,...))
|
||||
return zipped_divide(tensor, Tiler_MN{}).compose(TiledLayout_TV{}, _);
|
||||
}
|
||||
|
||||
template <class SliceCoord>
|
||||
struct TV_Partitioner
|
||||
{
|
||||
SliceCoord coord_;
|
||||
|
||||
template <class TargetTensor>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
partition(TargetTensor&& target) {
|
||||
Tensor thr_tensor = make_tensor(static_cast<TargetTensor&&>(target).data(), apply(target.layout()));
|
||||
return thr_tensor(coord_, repeat<rank_v<TargetTensor>>(_));
|
||||
}
|
||||
};
|
||||
|
||||
template <class SliceCoord>
|
||||
CUTE_HOST_DEVICE static
|
||||
auto
|
||||
get_slice(SliceCoord const& coord)
|
||||
{
|
||||
return TV_Partitioner<SliceCoord>{coord};
|
||||
}
|
||||
};
|
||||
|
||||
template <class Layout_TV,
|
||||
class Tiler_MN>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiler_impl(Layout_TV const&,
|
||||
Tiler_MN const&)
|
||||
{
|
||||
return TV_Tiler<Layout_TV, Tiler_MN>{};
|
||||
}
|
||||
|
||||
}
|
||||
@ -119,12 +119,16 @@ template <size_t N, class T>
|
||||
CUTE_HOST_DEVICE constexpr T getv(EBO<N, T, true> const&)
|
||||
{ return {}; }
|
||||
|
||||
// This is a work around approach to solve a shared memory misalign issue (https://github.com/NVIDIA/cutlass/issues/1250).
|
||||
// Will remove this work around implementation once the corresponding fix in compiler is released.
|
||||
struct dummy_EBO_base {};
|
||||
|
||||
// Specialization for types T that are not empty;
|
||||
// the "dynamic tuple leaf." Valid T here include int,
|
||||
// any other integral or floating-point type,
|
||||
// or any semiregular type for which std::is_empty_v<T> is false.
|
||||
template <size_t N, class T>
|
||||
struct EBO<N, T, false>
|
||||
struct EBO<N, T, false> : private dummy_EBO_base
|
||||
{
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
EBO() : t_{} {}
|
||||
|
||||
@ -78,6 +78,7 @@ using int_byte_t = typename int_byte<N>::type;
|
||||
using uint1_t = cutlass::uint1b_t;
|
||||
using uint2_t = cutlass::uint2b_t;
|
||||
using uint4_t = cutlass::uint4b_t;
|
||||
using uint6_t = cutlass::uint6b_t;
|
||||
using CUTE_STL_NAMESPACE::uint8_t;
|
||||
using CUTE_STL_NAMESPACE::uint16_t;
|
||||
using CUTE_STL_NAMESPACE::uint32_t;
|
||||
@ -88,6 +89,7 @@ template <int N> struct uint_bit;
|
||||
template <> struct uint_bit< 1> { using type = uint1_t; };
|
||||
template <> struct uint_bit< 2> { using type = uint2_t; };
|
||||
template <> struct uint_bit< 4> { using type = uint4_t; };
|
||||
template <> struct uint_bit< 6> { using type = uint6_t; };
|
||||
template <> struct uint_bit< 8> { using type = uint8_t; };
|
||||
template <> struct uint_bit< 16> { using type = uint16_t; };
|
||||
template <> struct uint_bit< 32> { using type = uint32_t; };
|
||||
|
||||
@ -73,6 +73,29 @@ using cutlass::uint4b_t;
|
||||
using cutlass::bin1_t;
|
||||
|
||||
|
||||
using cutlass::float_ue4m3_t;
|
||||
using cutlass::float_ue8m0_t;
|
||||
|
||||
using cutlass::uint6b_t;
|
||||
using cutlass::float_e2m1_t;
|
||||
using cutlass::float_e2m3_t;
|
||||
using cutlass::float_e3m2_t;
|
||||
|
||||
using cutlass::type_erased_dynamic_float6_t;
|
||||
using cutlass::type_erased_dynamic_float4_t;
|
||||
|
||||
namespace detail {
|
||||
using cutlass::detail::float_e2m1_unpacksmem_t;
|
||||
using cutlass::detail::float_e2m3_unpacksmem_t;
|
||||
using cutlass::detail::float_e3m2_unpacksmem_t;
|
||||
using cutlass::detail::float_e2m3_unpack8bits_t;
|
||||
using cutlass::detail::float_e3m2_unpack8bits_t;
|
||||
using cutlass::detail::type_erased_dynamic_float4_unpacksmem_t;
|
||||
using cutlass::detail::type_erased_dynamic_float6_unpacksmem_t;
|
||||
};
|
||||
|
||||
|
||||
|
||||
//
|
||||
// Print utility
|
||||
//
|
||||
@ -133,4 +156,26 @@ pretty_print(float_e5m2_t t) {
|
||||
printf("%*.2f", 8, static_cast<float>(t));
|
||||
}
|
||||
|
||||
|
||||
template <
|
||||
cutlass::detail::FpEncoding Encoding,
|
||||
class Derived
|
||||
>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
print(cutlass::float_exmy_base<Encoding, Derived> a) {
|
||||
printf("%f", static_cast<float>(a));
|
||||
}
|
||||
|
||||
template <
|
||||
cutlass::detail::FpEncoding Encoding,
|
||||
class Derived
|
||||
>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
pretty_print_float_exmy_base(cutlass::float_exmy_base<Encoding, Derived> t) {
|
||||
printf("%*.2f", 8, static_cast<float>(t));
|
||||
}
|
||||
|
||||
|
||||
} // namespace cute
|
||||
|
||||
@ -284,6 +284,96 @@ recast_ptr(rmem_ptr<P> const& ptr) {
|
||||
return make_rmem_ptr(recast_ptr<NewT>(ptr.get()));
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// tmem_ptr -- a typed, word-addressed, non-dereferencable "pointer"
|
||||
//
|
||||
|
||||
template <class T>
|
||||
struct tmem_ptr
|
||||
{
|
||||
using value_type = remove_cv_t<T>;
|
||||
using element_type = T;
|
||||
using reference = T;
|
||||
|
||||
// Right-shift value for the offset scaling -- TMEM uses word-addressing
|
||||
static constexpr int32_t OffsetShift = log_2(trait_ratio(sizeof_bits<uint32_t>{}, sizeof_bits<T>{}));
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
tmem_ptr(uint32_t addr = 0) : addr_(addr) {}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
uint32_t const& get() const {
|
||||
return addr_;
|
||||
}
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
uint32_t& get() {
|
||||
return addr_;
|
||||
}
|
||||
|
||||
template <class T_ = T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
value_type operator*() const {
|
||||
static_assert(dependent_false<T_>, "Attempting to dereference a tmem_ptr, want raw_pointer_cast() for address instead?");
|
||||
return value_type{};
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
reference operator[](uint32_t const& i) const { return *(*this + i); }
|
||||
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
tmem_ptr operator+(uint32_t const& i) const {
|
||||
//return {addr_ + shiftr(i, OffsetShift)}; // Shift the offset for word-addressing
|
||||
return {addr_ + rotr(i, OffsetShift)}; // Rotate the offset to keep subword indices in the unused high 8bits for debug
|
||||
}
|
||||
|
||||
// TMEM "Address" with active mask 0x007F.01FF
|
||||
// The upper 16 bits, the 0x007F portion, refers to the 128 DP lanes
|
||||
// The lower 16 bits, the 0x01FF portion, refers to the 512 COL lanes
|
||||
union {
|
||||
uint32_t addr_;
|
||||
struct {
|
||||
uint16_t col_;
|
||||
uint8_t dp_;
|
||||
uint8_t idx_; // Hijack the top 8bits for the sub-word idx to avoid an extra reg.
|
||||
// Assert this is 0 on every access?
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
template <class T, class = void>
|
||||
struct is_tmem : false_type {};
|
||||
template <class T> // Found the tmem
|
||||
struct is_tmem<tmem_ptr<T>> : true_type {};
|
||||
template <class P> // Recurse on ::iterator, if possible
|
||||
struct is_tmem<P, void_t<typename P::iterator>> : is_tmem<typename P::iterator> {};
|
||||
template <class P>
|
||||
constexpr bool is_tmem_v = is_tmem<P>::value;
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
tmem_ptr<T>
|
||||
make_tmem_ptr(uint32_t addr = 0) {
|
||||
return tmem_ptr<T>(addr);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
uint32_t
|
||||
raw_pointer_cast(tmem_ptr<T> const& ptr) {
|
||||
return ptr.get();
|
||||
}
|
||||
|
||||
// TMEM accounts for subword/superword elements already due to the offset shift based on sizeof_bits
|
||||
// Thus, this is a trivial recast equivalent to reinterpret_cast<NewT*>
|
||||
template <class NewT, class T>
|
||||
CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
recast_ptr(tmem_ptr<T> const& ptr) {
|
||||
return tmem_ptr<NewT>{ptr.addr_};
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Display utilities
|
||||
//
|
||||
@ -306,6 +396,14 @@ CUTE_HOST_DEVICE void print(rmem_ptr<T> ptr)
|
||||
printf("rmem_"); print(ptr.get());
|
||||
}
|
||||
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE void print(tmem_ptr<T> ptr)
|
||||
{
|
||||
printf("tmem_["); print(sizeof_bits<T>::value); printf("b](0x%04x.%04x)", ptr.addr_ >> 16, ptr.addr_ & 0xFFFF);
|
||||
}
|
||||
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
template <class T>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, gmem_ptr<T> ptr)
|
||||
@ -325,6 +423,13 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, rmem_ptr<T> ptr)
|
||||
return os << "rmem_[" << int(sizeof_bits<iter_value_t<T>>::value) << "b]";
|
||||
}
|
||||
|
||||
|
||||
template <class T>
|
||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, tmem_ptr<T> ptr)
|
||||
{
|
||||
return os << "tmem_[" << int(sizeof_bits<T>::value) << "b](" << ptr.addr_ << ")";
|
||||
}
|
||||
|
||||
#endif // !defined(__CUDACC_RTC__)
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
@ -95,6 +95,9 @@ template <class... Iters>
|
||||
struct is_smem<ZipIterator<Iters...>> : conjunction<is_smem<Iters>...> {};
|
||||
template <class... Iters>
|
||||
struct is_gmem<ZipIterator<Iters...>> : conjunction<is_gmem<Iters>...> {};
|
||||
template <class... Iters>
|
||||
struct is_tmem<ZipIterator<Iters...>> : conjunction<is_tmem<Iters>...> {};
|
||||
|
||||
// A tuple of Layouts that operates on each Layout symmetrically
|
||||
// The Layouts need to have compatible shapes and ranks.
|
||||
// The ZipLayout presents the intersection of the domain of its component Layouts.
|
||||
|
||||
@ -255,7 +255,12 @@ pretty_print(double v) {
|
||||
template <class T>
|
||||
CUTE_HOST_DEVICE void
|
||||
pretty_print(T t) {
|
||||
constexpr auto has_print_exmy_base = cute::is_valid([](auto t) -> decltype(pretty_print_float_exmy_base(t)) {}, t);
|
||||
if constexpr (has_print_exmy_base) {
|
||||
pretty_print_float_exmy_base(t);
|
||||
} else {
|
||||
printf(" "); print(t);
|
||||
}
|
||||
}
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
@ -41,6 +41,7 @@
|
||||
namespace cutlass {
|
||||
namespace arch {
|
||||
|
||||
constexpr int sm100_smem_capacity_bytes = 232448;
|
||||
#if defined(__NVCC__) || defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__))
|
||||
|
||||
/// Computes laneId within a warp
|
||||
@ -93,6 +94,12 @@ struct Sm90 {
|
||||
static int const kMinComputeCapability = 90;
|
||||
};
|
||||
|
||||
|
||||
struct Sm100 {
|
||||
static int const kMinComputeCapability = 100;
|
||||
};
|
||||
|
||||
|
||||
/// Triggers a breakpoint on the device
|
||||
CUTLASS_DEVICE
|
||||
void device_breakpoint() {
|
||||
|
||||
@ -36,12 +36,21 @@
|
||||
|
||||
#include <cutlass/arch/memory_sm75.h>
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/arch/copy_sm100_tma.hpp>
|
||||
#include <cutlass/arch/config.h>
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && (__CUDACC_VER_MAJOR__ >= 12)
|
||||
#define CUDA_BARRIER_ENABLED 1
|
||||
#else
|
||||
#define CUDA_BARRIER_ENABLED 0
|
||||
#endif
|
||||
|
||||
|
||||
#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED))
|
||||
#define CUTLASS_ARCH_TCGEN_ENABLED 1
|
||||
#endif
|
||||
|
||||
|
||||
namespace cutlass {
|
||||
/// @brief
|
||||
namespace arch {
|
||||
@ -140,6 +149,15 @@ void initialize_barrier_array_pair_aligned(uint64_t *full_barriers_ptr, uint64_t
|
||||
} // namespace detail end
|
||||
|
||||
|
||||
|
||||
|
||||
// There are 16 Named Barriers provided by Hardware starting in Hopper
|
||||
// Their IDs are in the range 0-15
|
||||
// Number of threads syncing using the barrier must be a multiple of warp-size
|
||||
// ID 0 should not be used for safety, as other driver APIs (i.e. __syncthreads)
|
||||
// may use it and conflict with other uses.
|
||||
|
||||
|
||||
// Enumerates the reserved named barriers to avoid potential conflicts
|
||||
// This enum class specifies the NamedBarriers reserved by CUTLASS.
|
||||
enum class ReservedNamedBarriers {
|
||||
@ -148,6 +166,7 @@ enum class ReservedNamedBarriers {
|
||||
TransformBarrier = 3,
|
||||
StreamkBarrier0 = 4,
|
||||
StreamkBarrier1 = 5
|
||||
, TmemAllocBarrier = 6
|
||||
, FirstUserBarrier = StreamkBarrier1 + 1
|
||||
};
|
||||
|
||||
@ -735,6 +754,152 @@ void cpasync_barrier_arrive_noinc(uint64_t const* smem_ptr) {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void umma_arrive(uint64_t const* smem_ptr) {
|
||||
#if defined(CUTLASS_ARCH_TCGEN_ENABLED)
|
||||
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr);
|
||||
if (cute::elect_one_sync()) {
|
||||
asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
|
||||
:
|
||||
:"r"(bar_intptr));
|
||||
}
|
||||
#elif defined(__CUDA_ARCH__)
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
//UMMA arrive for MMA_2x1SM
|
||||
CUTLASS_DEVICE
|
||||
void umma_arrive_2x1SM(uint64_t const* smem_ptr) {
|
||||
#if defined(CUTLASS_ARCH_TCGEN_ENABLED)
|
||||
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr);
|
||||
if (cute::elect_one_sync()) {
|
||||
asm volatile("tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.b64 [%0];"
|
||||
:
|
||||
:"r"(bar_intptr));
|
||||
}
|
||||
#elif defined(__CUDA_ARCH__)
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
// UMMA arrive for MMA_1sm + TMA_LOAD_MULTICAST combination
|
||||
CUTLASS_DEVICE
|
||||
void umma_arrive_multicast(uint64_t const* smem_ptr, uint16_t cta_mask) {
|
||||
#if defined(CUTLASS_ARCH_TCGEN_ENABLED)
|
||||
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr);
|
||||
if(cute::elect_one_sync()) {
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
"tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t"
|
||||
"}"
|
||||
:
|
||||
:"r"(bar_intptr), "h"(cta_mask));
|
||||
}
|
||||
#elif defined(__CUDA_ARCH__)
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
// UMMA arrive for MMA_2x1SM + TMA_LOAD_MULTICAST combination
|
||||
CUTLASS_DEVICE
|
||||
void umma_arrive_multicast_2x1SM(uint64_t const* smem_ptr, uint16_t cta_mask) {
|
||||
#if defined(CUTLASS_ARCH_TCGEN_ENABLED)
|
||||
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr);
|
||||
if (cute::elect_one_sync()) {
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
"tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t"
|
||||
"}"
|
||||
:
|
||||
:"r"(bar_intptr), "h"(cta_mask));
|
||||
}
|
||||
#else
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Temporary solution for sparse kernel.
|
||||
// Will remove this when we done tightly elect_one wrap.
|
||||
CUTLASS_DEVICE
|
||||
void umma_arrive_multicast_no_elect(uint64_t const* smem_ptr, uint16_t cta_mask) {
|
||||
#if defined(CUTLASS_ARCH_TCGEN_ENABLED)
|
||||
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
".reg .b16 lo, hi;\n\t"
|
||||
"mov.b32 {lo, hi}, %1;\n\t"
|
||||
"tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], lo; \n\t"
|
||||
"}"
|
||||
:
|
||||
:"r"(bar_intptr), "r"(uint32_t(cta_mask)));
|
||||
#elif defined(__CUDA_ARCH__)
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
#endif
|
||||
}
|
||||
|
||||
// Temporary solution for sparse kernel.
|
||||
// UMMA arrive for MMA_2x1SM + TMA_LOAD_MULTICAST combination
|
||||
CUTLASS_DEVICE
|
||||
void umma_arrive_multicast_2x1SM_no_elect(uint64_t const* smem_ptr, uint16_t cta_mask) {
|
||||
#if defined(CUTLASS_ARCH_TCGEN_ENABLED)
|
||||
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr);
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
".reg .b16 lo, hi;\n\t"
|
||||
"mov.b32 {lo, hi}, %1;\n\t"
|
||||
"tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], lo; \n\t"
|
||||
"}"
|
||||
:
|
||||
:"r"(bar_intptr), "r"(uint32_t(cta_mask)));
|
||||
#else
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
#endif
|
||||
}
|
||||
|
||||
// Always arrive on even SM of collaborating 2 SMs.
|
||||
CUTLASS_DEVICE
|
||||
void umma_arrive_2x1SM_sm0(uint64_t const* smem_ptr) {
|
||||
#if defined(CUTLASS_ARCH_TCGEN_ENABLED)
|
||||
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr) & cute::Sm100MmaPeerBitMask;
|
||||
asm volatile (
|
||||
"{\n\t"
|
||||
"mbarrier.arrive.shared::cluster.b64 _, [%0];\n\t"
|
||||
"}"
|
||||
:
|
||||
: "r"(bar_intptr));
|
||||
|
||||
#else
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTE_DEVICE static void fence_view_async_tmem_load() {
|
||||
#if defined(CUTLASS_ARCH_TCGEN_ENABLED)
|
||||
asm volatile (
|
||||
"{\n\t"
|
||||
"tcgen05.wait::ld.sync.aligned; \n"
|
||||
"}"
|
||||
::);
|
||||
#elif defined(__CUDA_ARCH__)
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTE_DEVICE static void fence_view_async_tmem_store() {
|
||||
#if defined(CUTLASS_ARCH_TCGEN_ENABLED)
|
||||
asm volatile (
|
||||
"{\n\t"
|
||||
"tcgen05.wait::st.sync.aligned; \n"
|
||||
"}"
|
||||
::);
|
||||
#elif defined(__CUDA_ARCH__)
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
} // end namespace arch
|
||||
} // end namespace cutlass
|
||||
|
||||
@ -51,21 +51,32 @@
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 2)
|
||||
#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 2))
|
||||
#define CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED
|
||||
#endif
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// SM90 Modifiable
|
||||
// SM90 Modifiable TMA
|
||||
#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 3))
|
||||
#define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED 1
|
||||
#if (!defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900)
|
||||
#if (!defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900)
|
||||
#define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED 1
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if (!defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM90_ALL))
|
||||
#define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90A_ENABLED 1
|
||||
#if (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ == 8)
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED)
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900 && \
|
||||
!defined(__CUDA_ARCH_FEAT_SM90_ALL)
|
||||
#undef CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1000 && \
|
||||
!defined(__CUDA_ARCH_FEAT_SM100_ALL)
|
||||
#undef CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@ -79,7 +90,29 @@
|
||||
#endif
|
||||
#endif
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// SM100, SM100a
|
||||
#if !CUTLASS_CLANG_CUDA && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
|
||||
#define CUTLASS_ARCH_MMA_SM100_SUPPORTED 1
|
||||
#if (!defined(CUTLASS_ARCH_MMA_SM100_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1000)
|
||||
#define CUTLASS_ARCH_MMA_SM100_ENABLED 1
|
||||
|
||||
#if (!defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM100_ALL))
|
||||
#define CUTLASS_ARCH_MMA_SM100A_ENABLED 1
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
|
||||
#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED))
|
||||
# define CUTLASS_ARCH_CLC_ENABLED
|
||||
#endif
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -129,6 +129,11 @@ struct OpClassWmmaTensorOp {};
|
||||
/// Tag classifying operators as Tensor Core with structure sparse operations.
|
||||
struct OpClassSparseTensorOp {};
|
||||
|
||||
|
||||
/// Tag classifying operators as Tensor Core with blockScaled
|
||||
struct OpClassBlockScaledTensorOp {};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Matrix multiply-add operation
|
||||
|
||||
@ -2567,7 +2567,6 @@ struct bit_not<Array<uint1b_t, N>> {
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// bit_xor
|
||||
template <int N>
|
||||
struct bit_xor<Array<uint1b_t, N>> {
|
||||
@ -2590,6 +2589,137 @@ struct bit_xor<Array<uint1b_t, N>> {
|
||||
}
|
||||
};
|
||||
|
||||
/// Fused and-popc-add
|
||||
template <typename T, int N>
|
||||
struct and_popc_add<Array<T, N>, Array<T, N>, Array<T, N>> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {
|
||||
Array<T, N> result;
|
||||
and_popc_add<T> scalar_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result[i] = scalar_op(a[i], b[i], c[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {
|
||||
Array<T, N> result;
|
||||
and_popc_add<T> scalar_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result[i] = scalar_op(a[i], scalar, c[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {
|
||||
Array<T, N> result;
|
||||
and_popc_add<T> scalar_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result[i] = scalar_op(scalar, b[i], c[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// Fused or-popc-add
|
||||
template <typename T, int N>
|
||||
struct or_popc_add<Array<T, N>, Array<T, N>, Array<T, N>> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {
|
||||
Array<T, N> result;
|
||||
or_popc_add<T> scalar_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result[i] = scalar_op(a[i], b[i], c[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {
|
||||
Array<T, N> result;
|
||||
or_popc_add<T> scalar_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result[i] = scalar_op(a[i], scalar, c[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {
|
||||
Array<T, N> result;
|
||||
or_popc_add<T> scalar_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result[i] = scalar_op(scalar, b[i], c[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
/// Fused xor-popc-add
|
||||
template <typename T, int N>
|
||||
struct xor_popc_add<Array<T, N>, Array<T, N>, Array<T, N>> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {
|
||||
Array<T, N> result;
|
||||
xor_popc_add<T> scalar_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result[i] = scalar_op(a[i], b[i], c[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {
|
||||
Array<T, N> result;
|
||||
xor_popc_add<T> scalar_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result[i] = scalar_op(a[i], scalar, c[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {
|
||||
Array<T, N> result;
|
||||
xor_popc_add<T> scalar_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result[i] = scalar_op(scalar, b[i], c[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Operator overloads
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -38,6 +38,8 @@
|
||||
#include <cuda_runtime_api.h>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/trace.h"
|
||||
#include <cute/arch/cluster_sm100.hpp>
|
||||
|
||||
#if defined(__CUDACC_RTC__)
|
||||
#include <cuda/std/type_traits>
|
||||
#else
|
||||
@ -49,6 +51,11 @@
|
||||
# define CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED
|
||||
#endif
|
||||
|
||||
#ifndef CUDA_ENABLE_PREFERRED_CLUSTER
|
||||
#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
|
||||
# define CUDA_ENABLE_PREFERRED_CLUSTER
|
||||
#endif
|
||||
#endif
|
||||
namespace cutlass {
|
||||
|
||||
#ifndef NDEBUG
|
||||
@ -78,7 +85,13 @@ struct ClusterLauncher {
|
||||
struct LaunchConfig {
|
||||
#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
|
||||
cudaLaunchConfig_t launch_config;
|
||||
|
||||
#if defined(CUDA_ENABLE_PREFERRED_CLUSTER)
|
||||
constexpr static int numAttrs = 3;
|
||||
#else
|
||||
|
||||
constexpr static int numAttrs = 2;
|
||||
#endif
|
||||
cudaLaunchAttribute launch_attribute[numAttrs];
|
||||
// Commonly used utility functions
|
||||
dim3 gridDim() { return launch_config.gridDim; }
|
||||
@ -143,6 +156,7 @@ struct ClusterLauncher {
|
||||
size_t const smem_size = 0,
|
||||
cudaStream_t cuda_stream = 0,
|
||||
bool launch_with_pdl = false
|
||||
, dim3 const fallback_cluster_dims = {0, 0, 0}
|
||||
) {
|
||||
LaunchConfig cluster_launch_config;
|
||||
#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
|
||||
@ -151,9 +165,37 @@ struct ClusterLauncher {
|
||||
auto numAttrs = cluster_launch_config.numAttrs;
|
||||
|
||||
launch_attribute[0].id = cudaLaunchAttributeClusterDimension;
|
||||
|
||||
bool have_fallback = fallback_cluster_dims.x * fallback_cluster_dims.y * fallback_cluster_dims.z > 0;
|
||||
|
||||
if (have_fallback) {
|
||||
launch_attribute[0].val.clusterDim = {fallback_cluster_dims.x, fallback_cluster_dims.y, fallback_cluster_dims.z};
|
||||
CUTLASS_TRACE_HOST("ClusterLauncher: Setting fallback ClusterDims = "
|
||||
"(" << fallback_cluster_dims.x << ", " << fallback_cluster_dims.y << ", " << fallback_cluster_dims.z << ")\n");
|
||||
}
|
||||
else {
|
||||
|
||||
launch_attribute[0].val.clusterDim = {cluster_dims.x, cluster_dims.y, cluster_dims.z};
|
||||
CUTLASS_TRACE_HOST("ClusterLauncher: Setting ClusterDims = "
|
||||
"(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n");
|
||||
|
||||
}
|
||||
|
||||
#if defined(CUDA_ENABLE_PREFERRED_CLUSTER)
|
||||
if (have_fallback) {
|
||||
if (cute::initialize_preferred_cluster_launch(nullptr, grid_dims, cluster_dims, fallback_cluster_dims)) {
|
||||
launch_attribute[1].id = cudaLaunchAttributePreferredClusterDimension;
|
||||
launch_attribute[1].val.preferredClusterDim = {cluster_dims.x, cluster_dims.y, cluster_dims.z};
|
||||
CUTLASS_TRACE_HOST("ClusterLauncher: Setting preferred ClusterDims = "
|
||||
"(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n");
|
||||
}
|
||||
}
|
||||
else {
|
||||
numAttrs--;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
// PDL attributes
|
||||
launch_attribute[numAttrs - 1].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
launch_attribute[numAttrs - 1].val.programmaticStreamSerializationAllowed = 1;
|
||||
@ -198,7 +240,7 @@ struct ClusterLauncher {
|
||||
return Status::kInvalid;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST("ClusterLauncher: Launching GPC_CLUSTER_GRID GridDims = "
|
||||
CUTLASS_TRACE_HOST("ClusterLauncher: Launching GridDims = "
|
||||
"(" << launch_grid_dims.x << ", " << launch_grid_dims.y << ", " << launch_grid_dims.z << "), "
|
||||
"And ClusterDims = "
|
||||
"(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n");
|
||||
@ -212,6 +254,53 @@ struct ClusterLauncher {
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
// This is the method we expect to use going forward
|
||||
// Launch a preferred cluster grid
|
||||
static inline CUTLASS_HOST
|
||||
Status launch_with_fallback_cluster(
|
||||
dim3 const grid_dims,
|
||||
dim3 const preferred_cluster_dims,
|
||||
dim3 const fallback_cluster_dims,
|
||||
dim3 const block_dims,
|
||||
size_t const smem_size,
|
||||
cudaStream_t cuda_stream,
|
||||
void const* kernel,
|
||||
void** kernel_params,
|
||||
bool launch_with_pdl = false) {
|
||||
#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
|
||||
LaunchConfig cluster_launch_config = make_cluster_launch_config(grid_dims, preferred_cluster_dims,
|
||||
block_dims, smem_size, cuda_stream, launch_with_pdl, fallback_cluster_dims);
|
||||
|
||||
auto launch_grid_dims = cluster_launch_config.gridDim();
|
||||
if (check_cluster_dims(launch_grid_dims, preferred_cluster_dims) != Status::kSuccess) {
|
||||
CUTLASS_TRACE_HOST("ClusterLauncher: check_cluster_dims() failed. Aborting.");
|
||||
return Status::kInvalid;
|
||||
}
|
||||
|
||||
auto init_status = init(kernel);
|
||||
if (init_status != Status::kSuccess) {
|
||||
CUTLASS_TRACE_HOST("ClusterLauncher: init(kernel) failed with status " << int(init_status) << ". Aborting.");
|
||||
return Status::kInvalid;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST("ClusterLauncher: Launching \n\tGridDims = "
|
||||
"(" << launch_grid_dims.x << ", " << launch_grid_dims.y << ", " << launch_grid_dims.z << "), "
|
||||
"\n\tPreferred ClusterDims = "
|
||||
"(" << preferred_cluster_dims.x << ", " << preferred_cluster_dims.y << ", " << preferred_cluster_dims.z << "),"
|
||||
"\n\tFallback ClusterDims = "
|
||||
"(" << fallback_cluster_dims.x << ", " << fallback_cluster_dims.y << ", " << fallback_cluster_dims.z << ")\n");
|
||||
|
||||
cutlass::arch::synclog_setup();
|
||||
cudaError_t status = cudaLaunchKernelExC(&cluster_launch_config.launch_config, kernel, kernel_params);
|
||||
Return_Status(status);
|
||||
#else
|
||||
CUTLASS_TRACE_HOST("ClusterLauncher: CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED not defined! Aborting cluster launch.");
|
||||
return Status::kInvalid;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
193
include/cutlass/conv/collective/builders/sm100_common.inl
Normal file
193
include/cutlass/conv/collective/builders/sm100_common.inl
Normal file
@ -0,0 +1,193 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
//
|
||||
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cute/atom/copy_traits_sm100_im2col.hpp"
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/dispatch_policy.hpp"
|
||||
#include "cutlass/detail/layout.hpp"
|
||||
#include "cutlass/conv/collective/builders/sm90_common.inl"
|
||||
#include "cutlass/gemm/collective/builders/sm100_common.inl"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::conv::collective::detail {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Collective tile traits struct that serves as a type list containing a tensor's mem layouts and atoms
|
||||
template<
|
||||
class GmemTiledCopy_,
|
||||
class SmemLayoutAtom_,
|
||||
class TmemLayoutAtom_ = void
|
||||
>
|
||||
struct Sm100ImplicitGemmTileTraits {
|
||||
using GmemTiledCopy = GmemTiledCopy_;
|
||||
using SmemLayoutAtom = SmemLayoutAtom_;
|
||||
using TmemLayoutAtom = TmemLayoutAtom_;
|
||||
};
|
||||
|
||||
template <class ClusterShapeMNK, class AtomThrId>
|
||||
constexpr auto
|
||||
sm100_cluster_shape_to_im2col_tma_atom_A(ClusterShapeMNK cluster_shape_mnk, AtomThrId atom_thr_id) {
|
||||
static_assert(cute::rank(cluster_shape_mnk) == 3);
|
||||
constexpr bool IsDynamicCluster = not cute::is_static_v<ClusterShapeMNK>;
|
||||
|
||||
if constexpr (cute::size(atom_thr_id) == 2) {
|
||||
if constexpr (!IsDynamicCluster) {
|
||||
static_assert(cute::size<0>(cluster_shape_mnk) % 2 == 0, "Cluster shape not divisible by MMA size");
|
||||
if constexpr (cute::size<1>(cluster_shape_mnk) == 1) {
|
||||
return cute::SM100_TMA_2SM_LOAD_IM2COL{};
|
||||
}
|
||||
else {
|
||||
return cute::SM100_TMA_2SM_LOAD_IM2COL_MULTICAST{};
|
||||
}
|
||||
}
|
||||
else {
|
||||
return cute::SM100_TMA_2SM_LOAD_IM2COL_MULTICAST{};
|
||||
}
|
||||
}
|
||||
else if constexpr (size(atom_thr_id) == 1) {
|
||||
if constexpr (!IsDynamicCluster) {
|
||||
return detail::sm90_cluster_shape_to_im2col_tma_atom(cute::size<1>(cluster_shape_mnk));
|
||||
}
|
||||
else {
|
||||
// In the case of dynamic cluster, multicast decision is not known at compile time.
|
||||
// A multicast instruction is forced by passing a cute::Int<2>{} to this helper.
|
||||
return detail::sm90_cluster_shape_to_im2col_tma_atom(cute::Int<2>{});
|
||||
}
|
||||
}
|
||||
else {
|
||||
static_assert(cutlass::detail::dependent_false<ClusterShapeMNK>,
|
||||
"Unsupported Configuration for SM100 TMA");
|
||||
}
|
||||
}
|
||||
|
||||
template <class ClusterShapeMNK, class AtomThrId>
|
||||
constexpr auto
|
||||
sm100_cluster_shape_to_im2col_tma_atom_B(ClusterShapeMNK cluster_shape_mnk, AtomThrId atom_thr_id) {
|
||||
static_assert(cute::rank(cluster_shape_mnk) == 3);
|
||||
constexpr bool IsDynamicCluster = not cute::is_static_v<ClusterShapeMNK>;
|
||||
|
||||
if constexpr (cute::size(atom_thr_id) == 2) {
|
||||
if constexpr (!IsDynamicCluster) {
|
||||
static_assert(cute::size<0>(cluster_shape_mnk) % 2 == 0, "Cluster shape not divisible by MMA size");
|
||||
if constexpr (cute::size<0>(cluster_shape_mnk) == 2) {
|
||||
return cute::SM100_TMA_2SM_LOAD_IM2COL{};
|
||||
}
|
||||
else {
|
||||
return cute::SM100_TMA_2SM_LOAD_IM2COL_MULTICAST{};
|
||||
}
|
||||
}
|
||||
else {
|
||||
return cute::SM100_TMA_2SM_LOAD_IM2COL_MULTICAST{};
|
||||
}
|
||||
} else if constexpr (size(atom_thr_id) == 1) {
|
||||
if constexpr (!IsDynamicCluster) {
|
||||
return detail::sm90_cluster_shape_to_im2col_tma_atom(cute::size<0>(cluster_shape_mnk));
|
||||
}
|
||||
else {
|
||||
// In the case of dynamic cluster, multicast decision is not known at compile time.
|
||||
// A multicast instruction is forced by passing a cute::Int<2>{} to this helper.
|
||||
return detail::sm90_cluster_shape_to_im2col_tma_atom(cute::Int<2>{});
|
||||
}
|
||||
}
|
||||
else {
|
||||
static_assert(cutlass::detail::dependent_false<ClusterShapeMNK>,
|
||||
"Unsupported Configuration for SM100 TMA");
|
||||
}
|
||||
}
|
||||
|
||||
template<
|
||||
class ElementA,
|
||||
class ElementB,
|
||||
class ElementAccumulator,
|
||||
class TileShape_MNK,
|
||||
class ClusterShape_MNK,
|
||||
UMMA::Major UmmaMajorA,
|
||||
UMMA::Major UmmaMajorB,
|
||||
class KernelScheduleType
|
||||
>
|
||||
constexpr auto
|
||||
sm100_make_tiled_mma() {
|
||||
// MMA_2SM requested
|
||||
if constexpr (cute::is_same_v<KernelScheduleType, KernelImplicitTmaWarpSpecialized2SmSm100>) {
|
||||
return cutlass::gemm::collective::detail::sm100_make_2sm_trivial_tiled_mma<
|
||||
ElementA, ElementB, ElementAccumulator,
|
||||
TileShape_MNK, ClusterShape_MNK, UmmaMajorA, UmmaMajorB>();
|
||||
}
|
||||
// MMA_1SM requested
|
||||
else if constexpr (cute::is_same_v<KernelScheduleType, KernelImplicitTmaWarpSpecialized1SmSm100>) {
|
||||
return cutlass::gemm::collective::detail::sm100_make_1sm_trivial_tiled_mma<
|
||||
ElementA, ElementB, ElementAccumulator,
|
||||
TileShape_MNK, ClusterShape_MNK, UmmaMajorA, UmmaMajorB>();
|
||||
}
|
||||
// Auto scheduling requested
|
||||
else if constexpr (cute::is_same_v<KernelScheduleType, KernelScheduleAuto>) {
|
||||
// Static cluster
|
||||
if constexpr (cute::is_static_v<ClusterShape_MNK>) {
|
||||
// For MMA_2SM we need a cluster shape that is multiple of 2x1
|
||||
// and only M=128 and M=256 are supported, otherwise, fall back to MMA_1SM
|
||||
if constexpr (cute::size<0>(ClusterShape_MNK{}) % 2 == 0 &&
|
||||
cute::size<0>(TileShape_MNK{}) % 128 == 0) {
|
||||
return cutlass::gemm::collective::detail::sm100_make_2sm_trivial_tiled_mma<
|
||||
ElementA, ElementB, ElementAccumulator,
|
||||
TileShape_MNK, ClusterShape_MNK, UmmaMajorA, UmmaMajorB>();
|
||||
}
|
||||
else {
|
||||
return cutlass::gemm::collective::detail::sm100_make_1sm_trivial_tiled_mma<
|
||||
ElementA, ElementB, ElementAccumulator,
|
||||
TileShape_MNK, ClusterShape_MNK, UmmaMajorA, UmmaMajorB>();
|
||||
}
|
||||
// Dynamic cluster shape means we cannot assume we can use 2SM MMA
|
||||
}
|
||||
else {
|
||||
return cutlass::gemm::collective::detail::sm100_make_1sm_trivial_tiled_mma<
|
||||
ElementA, ElementB, ElementAccumulator,
|
||||
TileShape_MNK, ClusterShape_MNK, UmmaMajorA, UmmaMajorB>();
|
||||
}
|
||||
}
|
||||
else {
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>,
|
||||
"Unsupported policy for SM100 collective builder.");
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::conv::collective::detail
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
225
include/cutlass/conv/collective/builders/sm100_umma_builder.inl
Normal file
225
include/cutlass/conv/collective/builders/sm100_umma_builder.inl
Normal file
@ -0,0 +1,225 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
//
|
||||
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/conv/collective/builders/sm100_common.inl"
|
||||
#include "cutlass/conv/collective/builders/sm90_gmma_builder.inl"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::conv::collective {
|
||||
using namespace cute;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
conv::Operator ConvOp,
|
||||
class ElementA,
|
||||
class GmemLayoutA,
|
||||
int AlignmentA,
|
||||
class ElementB,
|
||||
class GmemLayoutB,
|
||||
int AlignmentB,
|
||||
class ElementAccumulator,
|
||||
class TileShape_MNKL, // (MmaAtomShapeM, MmaAtomShapeN, TileK, optional: TileL)
|
||||
class ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1)
|
||||
class StageCountType,
|
||||
class KernelScheduleType
|
||||
>
|
||||
struct CollectiveBuilder<
|
||||
arch::Sm100,
|
||||
arch::OpClassTensorOp,
|
||||
ConvOp,
|
||||
ElementA,
|
||||
GmemLayoutA,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
GmemLayoutB,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape_MNKL,
|
||||
ClusterShape_MNK,
|
||||
StageCountType,
|
||||
KernelScheduleType,
|
||||
cute::enable_if_t<
|
||||
(cute::is_same_v<KernelScheduleType, KernelImplicitTmaWarpSpecialized1SmSm100> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelImplicitTmaWarpSpecialized2SmSm100> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelStridedDgradTmaWs1SmSm100> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelStridedDgradTmaWs2SmSm100> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelScheduleAuto>) &&
|
||||
((sizeof(ElementA) * AlignmentA) % cutlass::gemm::collective::detail::tma_alignment_bytes == 0) &&
|
||||
((sizeof(ElementB) * AlignmentB) % cutlass::gemm::collective::detail::tma_alignment_bytes == 0)>> {
|
||||
private:
|
||||
// For fprop, majorA = K, major B = K;
|
||||
// For wgrad, majorA = MN, major B = MN;
|
||||
// For dgrad, majorA = K, major B = MN;
|
||||
static constexpr cute::UMMA::Major UmmaMajorA =
|
||||
(ConvOp == conv::Operator::kWgrad) ? cute::UMMA::Major::MN : cute::UMMA::Major::K;
|
||||
static constexpr cute::UMMA::Major UmmaMajorB =
|
||||
(ConvOp == conv::Operator::kFprop) ? cute::UMMA::Major::K : cute::UMMA::Major::MN;
|
||||
|
||||
// For fp32 types, map to tf32 MMA value type
|
||||
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
|
||||
using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
|
||||
|
||||
using TileShape_MNK = decltype(cute::take<0,3>(TileShape_MNKL{})); // (MmaAtomShapeM, MmaAtomShapeN, TileK)
|
||||
|
||||
static constexpr auto
|
||||
get_tiled_mma_schedule() {
|
||||
if constexpr (cute::is_same_v<KernelScheduleType, KernelStridedDgradTmaWs1SmSm100>) {
|
||||
return KernelImplicitTmaWarpSpecialized1SmSm100{};
|
||||
}
|
||||
else if constexpr (cute::is_same_v<KernelScheduleType, KernelStridedDgradTmaWs2SmSm100>) {
|
||||
return KernelImplicitTmaWarpSpecialized2SmSm100{};
|
||||
}
|
||||
else {
|
||||
return KernelScheduleType{};
|
||||
}
|
||||
}
|
||||
|
||||
using TiledMmaSchedule = decltype(get_tiled_mma_schedule());
|
||||
using TiledMma = decltype(detail::sm100_make_tiled_mma<ElementAMma, ElementBMma, ElementAccumulator,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
UmmaMajorA, UmmaMajorB, TiledMmaSchedule>());
|
||||
|
||||
using AtomThrID = typename TiledMma::AtomThrID;
|
||||
|
||||
// ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K)
|
||||
using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}),
|
||||
cute::size<2>(TileShape_MNK{}))));
|
||||
// ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K)
|
||||
using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}),
|
||||
cute::size<2>(TileShape_MNK{}))));
|
||||
|
||||
static constexpr auto
|
||||
get_tma_atom_A() {
|
||||
if constexpr (cute::is_same_v<KernelScheduleType,KernelStridedDgradTmaWs1SmSm100> ||
|
||||
cute::is_same_v<KernelScheduleType,KernelStridedDgradTmaWs2SmSm100>) {
|
||||
static_assert(ConvOp == conv::Operator::kDgrad, "Operator+Schedule mismatch");
|
||||
return cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A(ClusterShape_MNK{}, AtomThrID{});
|
||||
}
|
||||
else if constexpr (ConvOp == conv::Operator::kWgrad) {
|
||||
return cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A(ClusterShape_MNK{}, AtomThrID{});
|
||||
}
|
||||
else {
|
||||
return cutlass::conv::collective::detail::sm100_cluster_shape_to_im2col_tma_atom_A(ClusterShape_MNK{}, AtomThrID{});
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr auto
|
||||
get_tma_atom_B() {
|
||||
if constexpr (cute::is_same_v<KernelScheduleType,KernelStridedDgradTmaWs1SmSm100> ||
|
||||
cute::is_same_v<KernelScheduleType,KernelStridedDgradTmaWs2SmSm100>) {
|
||||
static_assert(ConvOp == conv::Operator::kDgrad, "Operator+Schedule mismatch");
|
||||
return cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_B(ClusterShape_MNK{}, AtomThrID{});
|
||||
}
|
||||
else if constexpr (ConvOp == conv::Operator::kWgrad) {
|
||||
return cutlass::conv::collective::detail::sm100_cluster_shape_to_im2col_tma_atom_B(ClusterShape_MNK{}, AtomThrID{});
|
||||
}
|
||||
else {
|
||||
return cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_B(ClusterShape_MNK{}, AtomThrID{});
|
||||
}
|
||||
}
|
||||
|
||||
// For wgrad kernel, tensor A uses tma tiled mode and tensor B uses tma im2col mode.
|
||||
using GmemTiledCopyA = decltype(get_tma_atom_A());
|
||||
using GmemTiledCopyB = decltype(get_tma_atom_B());
|
||||
|
||||
using BlockTileA_M = decltype(cute::size<0,0>(MmaShapeA_MK{}) * cute::size<1>(MmaShapeA_MK{}));
|
||||
using BlockTileA_K = decltype(cute::size<0,1>(MmaShapeA_MK{}) * cute::size<2>(MmaShapeA_MK{}));
|
||||
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
|
||||
UmmaMajorA, ElementAMma, BlockTileA_M, BlockTileA_K>());
|
||||
|
||||
using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{}));
|
||||
using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{}));
|
||||
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
|
||||
UmmaMajorB, ElementBMma, BlockTileB_N, BlockTileB_K>());
|
||||
|
||||
// Calculate SMEM matrix A and B buffers' pipeline stages
|
||||
static constexpr uint32_t AccumulatorPipelineStageCount = 2;
|
||||
static constexpr uint32_t SchedulerPipelineStageCount = 2;
|
||||
static constexpr uint32_t CLCResponseSize = 16;
|
||||
|
||||
// AccumulatorPipeline = PipelineUmmaAsync
|
||||
static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount>::SharedStorage);
|
||||
// CLCPipeline = PipelineCLCFetchAsync
|
||||
static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount, ClusterShape_MNK>::SharedStorage);
|
||||
// LoadOrderBarrier = OrderedSequenceBarrier<1,2>
|
||||
static constexpr auto LoadOrderBarrierStorage = sizeof(typename cutlass::OrderedSequenceBarrier<1,2>::SharedStorage);
|
||||
// CLC (scheduler) response
|
||||
static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * CLCResponseSize;
|
||||
// CLC Throttle pipeline storage
|
||||
static constexpr auto CLCThrottlePipelineStorage = sizeof(typename cutlass::PipelineAsync<SchedulerPipelineStageCount>::SharedStorage);
|
||||
// Tmem dealloc
|
||||
static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier);
|
||||
// Tmem ptr storage
|
||||
static constexpr auto TmemBasePtrsStorage = SchedulerPipelineStageCount * sizeof(uint32_t);
|
||||
// Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage
|
||||
static constexpr auto KernelSmemCarveout = static_cast<int>( AccumulatorPipelineStorage +
|
||||
CLCPipelineStorage +
|
||||
LoadOrderBarrierStorage +
|
||||
TmemDeallocStorage +
|
||||
CLCThrottlePipelineStorage +
|
||||
CLCResponseStorage +
|
||||
TmemBasePtrsStorage);
|
||||
// Reduce SMEM capacity available for buffers considering barrier allocations.
|
||||
static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout;
|
||||
|
||||
using SmemTileShape = cute::Shape<BlockTileA_M, BlockTileB_N, BlockTileA_K>;
|
||||
|
||||
static constexpr int PipelineStages = detail::compute_stage_count_or_override<
|
||||
Sm100ReducedSmemCapacityBytes, ElementAMma, ElementBMma, SmemTileShape>(StageCountType{});
|
||||
|
||||
constexpr static int NumSpatialDimensions = detail::gmem_layout_tags_to_spatial_dims<GmemLayoutA, GmemLayoutB>();
|
||||
|
||||
using DispatchPolicy = cutlass::conv::MainloopSm100TmaUmmaWarpSpecializedImplicitGemm<
|
||||
ConvOp, PipelineStages, NumSpatialDimensions, ClusterShape_MNK>;
|
||||
|
||||
public:
|
||||
using CollectiveOp = cutlass::conv::collective::CollectiveConv<
|
||||
DispatchPolicy,
|
||||
TileShape_MNKL,
|
||||
ElementA,
|
||||
ElementB,
|
||||
TiledMma,
|
||||
detail::Sm100ImplicitGemmTileTraits<GmemTiledCopyA, SmemLayoutAtomA>,
|
||||
detail::Sm100ImplicitGemmTileTraits<GmemTiledCopyB, SmemLayoutAtomB>
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::conv::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -90,4 +90,5 @@ struct CollectiveBuilder {
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "builders/sm90_gmma_builder.inl"
|
||||
#include "builders/sm100_umma_builder.inl"
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -59,4 +59,5 @@ struct CollectiveConv {
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "sm90_implicit_gemm_gmma_ss_warpspecialized.hpp"
|
||||
#include "sm100_implicit_gemm_umma_warpspecialized.hpp"
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user