v4.0 update. (#2371)
CHANGELOG.md (62)
@@ -1,25 +1,31 @@
 # Changelog

 # CUTLASS 4.x

-## [4.0.0](https://github.com/NVIDIA/cutlass/tree/main) (2025-05-09)
+## [4.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.0.0) (2025-06-03)

 ### CuTe DSL
 * CuTe DSL, a Python DSL centered around CuTe's abstractions
 - [Core DSL implementation files](https://github.com/NVIDIA/cutlass/tree/main/python/CuTeDSL)
-- [DSL quick start](./media/docs/pythonDSL/quick_start.rst)
+- [DSL quick start](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html)
-- [DSL Overview](./media/docs/pythonDSL/overview.rst)
+- [DSL Overview](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/overview.html)
 * [Overhauled documentation with a new dedicated website](https://docs.nvidia.com/cutlass)
 * Set of examples demonstrating how to use CuTe DSL to write peak-performance kernels
 - [Blackwell persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py)
 - [Blackwell grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py)
 - [Blackwell fused multi-head attention forward pass](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/fmha.py)
+- [Hopper GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/hopper/dense_gemm.py)
 - [Ampere GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/tensorop_gemm.py)
 - [FlashAttention-2 implementation targeting Ampere and Ada class GPUs (SM80, SM86, SM89)](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py)
+- [SmemAllocator to facilitate shared memory allocation and management](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/smem_allocator.py)
+- [C-structure based customized interface between JIT function and user codes](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/cute/jit_argument.py)
 * [Educational notebooks for getting started with CuTe DSL](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks)
+* API updates
+- Fixed API mismatch in class ``cute.runtime.Pointer``: change ``element_type`` to ``dtype`` to match ``typing.Pointer``

 ### CUTLASS C++
 * Support [Family Specific Architecture Features](https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/), which were introduced in CUDA 12.9
-- 100f, 101f, 120f were added to support Family Specific Architecture Features, which allow running the same binary on different chips belonging to the same family (e.g. sm100) without recompiling.
+- 100f, 101f, 120f were added to support Family Specific Architecture Features, which allow running the same binary on different chips belonging to the same family (e.g. sm100) without recompiling. Note that 101a has been supported since CUTLASS 3.9.
 * Instruction shapes and redundant accumulation type have been removed from CUTLASS 3.x-style library kernel names to disambiguate kernels and shorten names.
 - For example:

@@ -30,9 +36,25 @@
 - Added non-power-of-two tile sizes.
 - Improved performance for K-major scale factors.
 - The argument `mma_promotion_interval` has been removed from non-grouped GEMM to align with the grouped and Blackwell versions.
+* Support LSE output in the Blackwell FMHA Forward kernel in example 77.
+* Improve Blackwell and Hopper grouped GEMM performance, functionality, and profiler support.
+- Enable runtime datatype for Blackwell grouped GEMM. Profiler support is also added.
+- Enable kernel parameter exploration for Blackwell grouped GEMM (raster_order, swizzle).
+* Add [Blackwell SM100 implicit GEMM conv fprop/dgrad/wgrad unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/).
+* Add dynamic and preferred cluster support for convolution kernels.
+* Support for Blackwell SM120 blockwise dense GEMM in the CUTLASS core library, as well as the CUTLASS profiler.
+* Fix profiler issues that caused no output or a "not supported" error for some kernels.
+* Optimization porting for BlockScaled collectives and kernel layers.
+* New [Hopper FMHA example](https://github.com/NVIDIA/cutlass/tree/main/examples/88_hopper_fmha/), similar in design to the existing [Blackwell FMHA](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
+* CuTe changes:
+- Rework `cute::copy_if` so that the predicate tensor is a true CuTe Tensor rather than a lambda, and introduce transform-tensors to avoid any extra register or load/store overhead when using bool tensors (see the sketch after this hunk).
+- New [CuTe tutorial](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/tiled_copy_if.cu) to show the usage of copy_if in tile copy.
+- Add [CuTe C++ reduce op](https://github.com/NVIDIA/cutlass/tree/main/include/cute/algorithm/tensor_reduce.hpp).
+- Add several [unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/core/tensor_algs.cpp) for CuTe tensor algorithms.
 * Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
 * Optimal code generation with CUDA toolkit version 12.9.


 # CUTLASS 3.x

 ## [3.9.2](https://github.com/NVIDIA/cutlass/releases/tag/v3.9.2) (2025-05-03)
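A minimal sketch of how the reworked `cute::copy_if` reads with a tensor-valued predicate, per the item above. This is not code from the commit: the host-side setting, the shapes, and the data are illustrative assumptions, and the exact overload set of `copy_if` should be checked against the CuTe headers.

```cpp
#include <cute/tensor.hpp>

// Hedged sketch: predicate-driven copy where the predicate is itself a CuTe
// tensor of bool, congruent with src/dst, rather than a lambda.
void masked_copy_sketch() {
  using namespace cute;

  float src_data[8], dst_data[8];
  bool  prd_data[8];
  for (int i = 0; i < 8; ++i) {
    src_data[i] = float(i);
    dst_data[i] = -1.0f;
    prd_data[i] = (i % 2 == 0);   // copy even-indexed elements only
  }

  auto src = make_tensor(&src_data[0], make_shape(Int<8>{}));
  auto dst = make_tensor(&dst_data[0], make_shape(Int<8>{}));
  auto prd = make_tensor(&prd_data[0], make_shape(Int<8>{}));

  copy_if(prd, src, dst);  // elements with a false predicate are left untouched
}
```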
@@ -82,7 +104,7 @@
 - Sorting performance results by GFLOPs/second: Users can now sort the final performance report based on GFLOPs/second, making it easier to identify the most efficient kernels.
 - Exhaustive search for best kernel performance in GFLOPs/second: The profiler now searches for the best-performing kernel across a range of problem sizes, swizzle sizes, rasterization orders, and dynamic cluster configurations to maximize performance.
 - Performance search under a fixed GEMM shape: Enables exhaustive tuning within a fixed GEMM shape, exploring various kernel parameters to find the best configuration.
-- More detailed introductions and examples to leverage this feature can be found in [profiler.md](./media/docs/cpp/profiler.md#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss).
+- More detailed introductions and examples to leverage this feature can be found in [profiler.md](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss).
 * Support `void` as the D element in sm100 kernel epilogues.
 * Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
 * Optimal code generation with CUDA toolkit version 12.8U1.
@@ -101,7 +123,7 @@
 - [Pipelines that implement Blackwell specific synchronization](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/pipeline/sm100_pipeline.hpp).
 - [Cluster launch control API supporting preferred and fallback cluster shapes](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/cluster_launch.hpp).
 - Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types.
-- Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./media/docs/cpp/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp).
+- Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_cluster_launch_control.html) to implement dynamic persistence scheduling for [GEMMs](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp).
 - Extensions to testbeds and reference check code for unit tests and CUTLASS profiler.
 * Full support for Blackwell SM100 kernels in CUTLASS 3.x API:
 - [Blackwell specific kernel layers](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that
@@ -139,11 +161,11 @@
 - A set of new [Hopper grouped GEMM kernels](https://github.com/NVIDIA/cutlass/tree/main/examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes.
 - A new [Hopper FP8 GEMM with groupwise scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu).
 * Documentation updates:
-- [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/cpp/quickstart.md#instantiating-a-blackwell-sm100-gemm-kernel).
+- [Quickstart - instantiating a Blackwell block-scaled GEMM](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#instantiating-a-blackwell-sm100-gemm-kernel).
-- Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/cpp/blackwell_functionality.md)
+- Detailed [Blackwell block-scaled GEMM functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html)
-- A new [functionality documentation](./media/docs/cpp/functionality.md) specifically for the 3.x API, comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA toolkit support, etc. for 3.x-supported architectures.
+- A new [functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html) specifically for the 3.x API, comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA toolkit support, etc. for 3.x-supported architectures.
-- Updates to the [compatibility](./README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](./README.md#target-architecture).
+- Updates to the [compatibility](https://docs.nvidia.com/cutlass/overview.html#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](https://docs.nvidia.com/cutlass/overview.html#target-architecture).
-- Updates to the [profiler documentation](./media/docs/cpp/profiler.md) for testing mixed input GEMM kernels on Hopper.
+- Updates to the [profiler documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html) for testing mixed input GEMM kernels on Hopper.

 ## [3.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.7.0) (2025-01-11)
 - [Hopper blockwise scaling FP8 GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) uses a 2D scaling tensor, assigning one value per threadblock. This allows a finer-grained scaling to be applied for each output tile per gemm-k iteration. The operands and scaling tensors are loaded from global memory to shared memory using TMA and cp_async, respectively. The scaling is applied inside the mainloop. Details with figures are [here](https://github.com/NVIDIA/cutlass/pull/1932#issue-2645398439).
@@ -156,7 +178,7 @@
 + Fix `cute::SM80_CP_ASYNC_CACHEALWAYS`, `cute::SM80_CP_ASYNC_CACHEGLOBAL`, `cute::SM80_CP_ASYNC_CACHEALWAYS_ZFILL`, `cute::SM80_CP_ASYNC_CACHEGLOBAL_ZFILL` to avoid implicitly selecting `ZFILL` behavior on predication.
 + Remove `cute::copy_vec<T>` in favor of `cute::copy_aligned` and `cute::copy(AutoVectorizingCopyWithAssumedAlignment<NumBits>,...)` (a migration sketch follows this hunk).
 + A refactor of the default epilogue struct `DefaultEpilogue` [API](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/default_epilogue.hpp) to avoid reading a non-void `ElementC` value for `ElementC = void` kernels.
-- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](./media/docs/cpp/profiler.md#cutlass-profiler).
+- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#cutlass-profiler).
 - Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
 - Optimal code generation with CUDA toolkit version 12.6.

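A hedged illustration of the `cute::copy_vec<T>` removal noted above. The two replacements are exactly the calls named in the item; the wrapper function, its name, and the generic tensor parameters are placeholders of mine, not code from the commit.

```cpp
#include <cute/tensor.hpp>

// Sketch only: migrating a vectorized tile copy off the removed cute::copy_vec<T>.
template <class SrcEngine, class SrcLayout, class DstEngine, class DstLayout>
CUTE_HOST_DEVICE void
copy_tile(cute::Tensor<SrcEngine, SrcLayout> const& src,
          cute::Tensor<DstEngine, DstLayout>      & dst)
{
  // Before (removed):  cute::copy_vec<cute::uint128_t>(src, dst);

  // Option 1: let CuTe derive a safe vector width from the layouts and types.
  cute::copy_aligned(src, dst);

  // Option 2: assert a known alignment explicitly (128 bits assumed here).
  // cute::copy(cute::AutoVectorizingCopyWithAssumedAlignment<128>{}, src, dst);
}
```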
@@ -168,14 +190,14 @@
 + [INT8](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu)
 + [TF32](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu)
 - A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. The 3.x convolution API is no longer considered a beta API.
-- [An improved mixed input GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/cpp/README.md) and a [lookup table implementation](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode.
+- [An improved mixed input GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode.
 - [EVT nodes for Top-K selection and softmax](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](https://github.com/NVIDIA/cutlass/tree/main/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu).
-- [Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speed up two back-to-back kernels, and its corresponding [documentation](./media/docs/cpp/dependent_kernel_launch.md). A minimal usage sketch follows this hunk.
+- [Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speed up two back-to-back kernels, and its corresponding [documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/dependent_kernel_launch.html). A minimal usage sketch follows this hunk.
-- [A new debugging tool, synclog](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see the [synclog documentation](./media/docs/cpp/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details.
+- [A new debugging tool, synclog](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see the [synclog documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/utilities.html#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details.
 - A new TMA-enabled [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings significant performance improvement, as well as its EVT support.
 - A SIMT-enabled pointer-array [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp).
 - A new [Ping-Pong kernel schedule for Grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations.
-- [A new instantiation strategy for CUTLASS profiler kernels](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/cpp/profiler.md#instantiating-more-kernels-with-hopper).
+- [A new instantiation strategy for CUTLASS profiler kernels](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#instantiating-more-kernels-with-hopper).
 - New hardware support for comparisons and computations of [`cutlass::bfloat16_t`](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/bfloat16.h)
 - Fixed use of isnan on Windows for [`half_t`](https://github.com/NVIDIA/cutlass/tree/main/test/unit/core/functional.cu).
 - Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
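A minimal, hedged sketch of the Programmatic Dependent Launch pattern referenced above, using only `cutlass::arch::launch_dependent_grids()` from the linked header. The kernel itself, its name, and its launch configuration are illustrative assumptions, not CUTLASS code.

```cpp
#include "cutlass/arch/grid_dependency_control.h"

// Illustrative producer kernel: once its useful work has been issued, it
// signals that the next kernel launched in the same stream may begin its
// launch early. On architectures without griddepcontrol support this is a no-op.
__global__ void producer_kernel(float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = 2.0f * static_cast<float>(i);
  }
  // Allow the dependent grid to start launching; correctness still relies on
  // the consumer synchronizing on the dependency before it reads `out`.
  cutlass::arch::launch_dependent_grids();
}
```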
@@ -198,7 +220,7 @@
 - Support for residual add (beta != 0) in convolution kernels.
 - A new convolution [epilogue](https://github.com/NVIDIA/cutlass/tree/main/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output.
 - A refactor of [include files throughout CUTLASS core directories](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](https://github.com/NVIDIA/cutlass/tree/main/test/self_contained_includes/CMakeLists.txt).
-- [A guide for setting up VSCode to work well with CUTLASS](./media/docs/cpp/ide_setup.md) and [expanded code style guide](./media/docs/cpp/programming_guidelines.md).
+- [A guide for setting up VSCode to work well with CUTLASS](https://docs.nvidia.com/cutlass/media/docs/cpp/ide_setup.html) and [expanded code style guide](https://docs.nvidia.com/cutlass/media/docs/cpp/programming_guidelines.html).
 - Better support for MSVC as a host compiler.
 - Many performance optimizations, improvements, and bug fixes including fixes for FlashAttention-2.
 - Optimal code generation with CUDA toolkit versions 12.4 and 12.5u1.
@@ -206,13 +228,13 @@
 ## [3.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.0) (2024-04-09)

 - Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/copy_traits_sm90_im2col.hpp)
-+ Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/cpp/gemm_api_3x.md).
++ Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api_3x.html).
 + Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/convnd_problem_shape.hpp).
 + Support for [Fprop](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms
 + [CUTLASS profiler support](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API.
 + NOTE: this is a beta release. Further updates to CUTLASS will include major performance improvements, feature enablement, and possible breaking changes to the API until 3.7 release. Your feedback is welcome on the design!
 - Support for [Ada (SM89) FP8 tensor cores via the 2.x API](https://github.com/NVIDIA/cutlass/tree/main/examples/58_ada_fp8_gemm/ada_fp8_gemm.cu). Requires CUDA 12.4 or newer.
-- [Ampere gather/scatter convolution example](https://github.com/NVIDIA/cutlass/tree/main/examples/59_ampere_gather_scatter_conv/cpp/README.md) in CuTe and CUTLASS 3.x
+- [Ampere gather/scatter convolution example](https://github.com/NVIDIA/cutlass/tree/main/examples/59_ampere_gather_scatter_conv/README.md) in CuTe and CUTLASS 3.x
 + Showcasing how custom kernels can be written and optimized using CUTLASS 3.x and CuTe and the general strategy for implementing convolutions as specializations of GETTs.
 + Implementation of a coarse grained sparse gather/scatter kernel achieving peak performance on Ampere class tensor cores.
 - 32x and 16x tile sizes are added to CUTLASS 2.x to improve the performance of narrow-tall and wide-short matrices.
@@ -279,7 +301,7 @@
 * Updates and bugfixes from the community (thanks!)

 ## [3.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.1.0) (2023-04-14)
-* New CUTLASS Python interface that aims to provide an easy-to-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](https://github.com/NVIDIA/cutlass/tree/main/python/cpp/README.md) and new [examples](https://github.com/NVIDIA/cutlass/tree/main/examples/python).
+* New CUTLASS Python interface that aims to provide an easy-to-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](https://github.com/NVIDIA/cutlass/tree/main/python/README.md) and new [examples](https://github.com/NVIDIA/cutlass/tree/main/examples/python).
 * New [efficient epilogues](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) using TMA for Hopper.
 * Support for [fused epilogues](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu), such as Bias, ReLU, and GELU, using the new efficient epilogues.
 * New [warp-specialized TensorFloat-32 (TF32) GEMM kernels](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
@@ -175,11 +175,13 @@ if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
 endif()

 if (CUDA_VERSION VERSION_GREATER_EQUAL 12.8)
-  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a 101 101a 120 120a)
+  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a 120 120a)
+  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101 101a)
 endif()

 if (CUDA_VERSION VERSION_GREATER_EQUAL 12.9)
-  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100f 101f 120f)
+  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100f 120f)
+  list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101f)
 endif()

 set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
@@ -344,6 +346,10 @@ if(CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES)
   list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
 endif()

+if (CUTLASS_NVCC_ARCHS MATCHES 100f OR CUTLASS_NVCC_ARCHS MATCHES 101f)
+  list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_SM100_FAMILY_ARCHS_ENABLED)
+endif()

 set(CUTLASS_SKIP_REDUCTION_INIT OFF CACHE BOOL "Disable init reduction workspace")

 #
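The hunks above wire the new family-specific targets (100f, 101f, 120f) into the build and define `CUTLASS_SM100_FAMILY_ARCHS_ENABLED` when one of them is requested. Below is a hedged sketch of how user code might key off that macro; the guard macro is the one defined in the hunk above, but the kernel and the overall guarding scheme are illustrative assumptions of mine, not part of the commit.

```cpp
// Sketch only: compile a Blackwell-family code path when the build requested a
// family-specific target (e.g. CUTLASS_NVCC_ARCHS=100f), so the same binary can
// run across chips of the sm100 family without recompiling.
#if defined(CUTLASS_SM100_FAMILY_ARCHS_ENABLED)
__global__ void family_specific_kernel(int* out) {
  // Hypothetical kernel body; restrict it to features common to the family.
  out[threadIdx.x] = static_cast<int>(threadIdx.x);
}
#endif
```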
Doxyfile (4)
@@ -1056,7 +1056,7 @@ HTML_STYLESHEET =
 # defined cascading style sheet that is included after the standard style sheets
 # created by doxygen. Using this option one can overrule certain style aspects.
 # This is preferred over using HTML_STYLESHEET since it does not replace the
-# standard style sheet and is therefore more robust against future updates.
+# standard style sheet and is therefor more robust against future updates.
 # Doxygen will copy the style sheet file to the output directory. For an example
 # see the documentation.
 # This tag requires that the tag GENERATE_HTML is set to YES.
@@ -1940,7 +1940,7 @@ PREDEFINED =
 EXPAND_AS_DEFINED =

 # If the SKIP_FUNCTION_MACROS tag is set to YES then doxygen's preprocessor will
-# remove all references to function-like macros that are alone on a line, have an
+# remove all refrences to function-like macros that are alone on a line, have an
 # all uppercase name, and do not end with a semicolon. Such function macros are
 # typically used for boiler-plate code, and will confuse the parser if not
 # removed.
README.md (61)
@@ -40,21 +40,22 @@ designs, and bringing optimized solutions into production.
 CuTe DSL is currently in public beta and will graduate out of beta by the end of summer 2025.

 To get started quickly, please refer to:
-- [CUTLASS C++ Quick Start Guide](./media/docs/cpp/quickstart.md).
+- [CUTLASS C++ Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html).
-- [CuTe DSL Quick Start Guide](./media/docs/pythonDSL/quick_start.rst).
+- [CuTe DSL Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html).

 # What's New in CUTLASS 4.0

 ## CuTe DSL
 * CuTe DSL, a Python DSL centered around CuTe's abstractions
 - [Core DSL implementation files](https://github.com/NVIDIA/cutlass/tree/main/python/CuTeDSL)
-- [DSL Quick Start](./media/docs/pythonDSL/quick_start.rst)
+- [DSL Quick Start](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html)
-- [DSL Overview](./media/docs/pythonDSL/overview.rst)
+- [DSL Overview](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/overview.html)
 * [Overhauled documentation with a new dedicated website](https://docs.nvidia.com/cutlass)
 * Set of examples demonstrating how to use CuTe DSL to write peak-performance kernels
 - [Blackwell persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py)
 - [Blackwell grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py)
 - [Blackwell fused multi-head attention forward pass](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/fmha.py)
+- [Hopper GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/hopper/dense_gemm.py)
 - [Ampere GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/tensorop_gemm.py)
 - [FlashAttention-2 implementation targeting Ampere and Ada class GPUs (SM80, SM86, SM89)](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py)
 * [Educational notebooks for getting started with CuTe DSL](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks)
@@ -71,13 +72,15 @@ To get started quickly, please refer to:
 - Added non-power-of-two tile sizes.
 - Improved performance for K-major scale factors.
 - The argument `mma_promotion_interval` has been removed from non-grouped GEMM to align with the grouped and Blackwell versions.
+* Support LSE output in the Blackwell FMHA Forward kernel.
+* Improve Blackwell and Hopper grouped GEMM performance, functionality, and profiler support.
 * Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
 * Optimal code generation with CUDA toolkit version 12.9.

 Note: CUTLASS 4.x builds are known to be broken on Windows platforms for all CUDA toolkits.
 The CUTLASS team is working on a fix.

-**See the [CHANGELOG](CHANGELOG.md) for details of all past releases and updates.**
+**See the [CHANGELOG](https://docs.nvidia.com/cutlass/CHANGELOG.html) for details of all past releases and updates.**

 # Performance

@@ -119,7 +122,7 @@ Layouts can also be combined and manipulated via functional composition, on whic
 CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates.
 This greatly simplifies the design and improves code composability and readability.
 More documentation specific to CuTe can be found in its
-[dedicated documentation directory](./media/docs/cpp/cute/00_quickstart.md).
+[dedicated documentation directory](https://docs.nvidia.com/cutlass/media/docs/cpp/cute/00_quickstart.html).

 # Compatibility

@@ -202,7 +205,7 @@ NVIDIA Blackwell GeForce RTX 50 series GPUs. As a result, kernels
 compiled for Blackwell SM100 architecture with arch conditional features
 (using `sm100a`) are not compatible with RTX 50 series GPUs.

-Please refer to the [functionality documentation](./media/docs/cpp/functionality.md)
+Please refer to the [functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html)
 for details on which kernels require which target architectures.

 # Documentation
@@ -210,22 +213,22 @@ for details on which kernels require which target architectures.
 CUTLASS is described in the following documents and the accompanying
 [Doxygen documentation](https://nvidia.github.io/cutlass).

-- [Quick Start Guide](./media/docs/cpp/quickstart.md) - basics of building and running CUTLASS
+- [Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html) - basics of building and running CUTLASS
-- [Functionality](./media/docs/cpp/functionality.md) - summarizes functionality available in CUTLASS
+- [Functionality](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html) - summarizes functionality available in CUTLASS
-- [Efficient GEMM in CUDA](./media/docs/cpp/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA
+- [Efficient GEMM in CUDA](https://docs.nvidia.com/cutlass/media/docs/cpp/efficient_gemm.html) - describes how GEMM kernels may be implemented efficiently in CUDA
-- [CUTLASS 3.x Design](./media/docs/cpp/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
+- [CUTLASS 3.x Design](https://docs.nvidia.com/cutlass/media/docs/cpp/cutlass_3x_design.html) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
-- [GEMM API 3.x](./media/docs/cpp/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts
+- [GEMM API 3.x](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api_3x.html) - describes the CUTLASS 3.x GEMM model and C++ template concepts
-- [GEMM API 2.x](./media/docs/cpp/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts
+- [GEMM API 2.x](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api.html) - describes the CUTLASS 2.x GEMM model and C++ template concepts
-- [Implicit GEMM Convolution](./media/docs/cpp/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS
+- [Implicit GEMM Convolution](https://docs.nvidia.com/cutlass/media/docs/cpp/implicit_gemm_convolution.html) - describes 2-D and 3-D convolution in CUTLASS
-- [Code Organization](./media/docs/cpp/code_organization.md) - describes the organization and contents of the CUTLASS project
+- [Code Organization](https://docs.nvidia.com/cutlass/media/docs/cpp/code_organization.html) - describes the organization and contents of the CUTLASS project
-- [Terminology](./media/docs/cpp/terminology.md) - describes terms used in the code
+- [Terminology](https://docs.nvidia.com/cutlass/media/docs/cpp/terminology.html) - describes terms used in the code
-- [Programming Guidelines](./media/docs/cpp/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++
+- [Programming Guidelines](https://docs.nvidia.com/cutlass/media/docs/cpp/programming_guidelines.html) - guidelines for writing efficient modern CUDA C++
-- [Fundamental types](./media/docs/cpp/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
+- [Fundamental types](https://docs.nvidia.com/cutlass/media/docs/cpp/fundamental_types.html) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
-- [Layouts](./media/docs/cpp/layout.md) - describes layouts of matrices and tensors in memory
+- [Layouts](https://docs.nvidia.com/cutlass/media/docs/cpp/layout.html) - describes layouts of matrices and tensors in memory
-- [Tile Iterators](./media/docs/cpp/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory
+- [Tile Iterators](https://docs.nvidia.com/cutlass/media/docs/cpp/tile_iterator_concept.html) - describes C++ concepts for iterating over tiles of matrices in memory
-- [CUTLASS Profiler](./media/docs/cpp/profiler.md) - command-line driven profiling application
+- [CUTLASS Profiler](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html) - command-line driven profiling application
-- [CUTLASS Utilities](./media/docs/cpp/utilities.md) - additional templates used to facilitate rapid development
+- [CUTLASS Utilities](https://docs.nvidia.com/cutlass/media/docs/cpp/utilities.html) - additional templates used to facilitate rapid development
-- [Dependent kernel launch](./media/docs/cpp/dependent_kernel_launch.md) - describes a new feature in Hopper which allows overlapping dependent
+- [Dependent kernel launch](https://docs.nvidia.com/cutlass/media/docs/cpp/dependent_kernel_launch.html) - describes a new feature in Hopper which allows overlapping dependent
 kernels in the same stream, and how it is used in CUTLASS.

 # Resources
@@ -245,7 +248,7 @@ projects. Client applications should target CUTLASS's `include/` directory in th
 paths.

 CUTLASS unit tests, examples, and utilities can be built with CMake.
-The minimum version of CMake is given in the [Quickstart guide](./media/docs/cpp/quickstart.md).
+The minimum version of CMake is given in the [Quickstart guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html).
 Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed
 on your system.

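A hedged sketch of the header-only usage model the hunk above refers to: a client translation unit only needs CUTLASS's `include/` directory on its include path, with nothing to link against. The tiny program, its file name, and the chosen header are illustrative choices of mine, not taken from the repository's documentation.

```cpp
// Sketch: compile with the CUTLASS include path on the command line, e.g.
//   nvcc -I/path/to/cutlass/include header_only_check.cu -o header_only_check
// (the path and file name here are placeholders).
#include <cstdio>
#include "cutlass/numeric_types.h"   // header-only numeric types such as cutlass::half_t

int main() {
  cutlass::half_t x(1.5f);
  cutlass::half_t y(2.0f);
  float sum = float(x) + float(y);   // convert back to float for host-side printing
  std::printf("half_t sum = %f\n", sum);
  return 0;
}
```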
@@ -290,7 +293,7 @@ CUTLASS is arranged as a header-only library along with Utilities, Tools, Exampl
 and template concepts defined in the CUTLASS project.

 A detailed explanation of the source code organization may be found in the
-[CUTLASS documentation](./media/docs/cpp/code_organization.md), but several main components are summarized below.
+[CUTLASS documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/code_organization.html), but several main components are summarized below.

 ## CUTLASS Template Library

@@ -364,7 +367,7 @@ tools/
 The `test/unit/` directory consists of unit tests implemented with Google Test that demonstrate
 basic usage of Core API components and complete tests of the CUTLASS GEMM computations.

-Instructions for building and running the Unit tests are described in the [Quickstart guide](./media/docs/cpp/quickstart.md).
+Instructions for building and running the Unit tests are described in the [Quickstart guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html).

 # Performance Profiling

@@ -580,9 +583,9 @@ reference_device: Passed

 ## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler
 - Please follow the links for more CMake examples on selectively compiling CUTLASS kernels:
-- [GEMM CMake Examples](./media/docs/cpp/quickstart.md#gemm-cmake-examples)
+- [GEMM CMake Examples](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#gemm-cmake-examples)
-- [Implicit GEMM convolution CMake Examples](./media/docs/cpp/quickstart.md#convolution-cmake-examples)
+- [Implicit GEMM convolution CMake Examples](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#convolution-cmake-examples)
-- [Further details about the CUTLASS Profiler are described here.](./media/docs/cpp/profiler.md)
+- [Further details about the CUTLASS Profiler are described here.](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html)


 # About
@@ -42,7 +42,6 @@
 #include "cute/algorithm/functional.hpp"
 #include "cute/atom/mma_atom.hpp"
 #include "cute/algorithm/gemm.hpp"
-#include "cute/tensor_predicate.hpp"
 #include "cute/numeric/arithmetic_tuple.hpp"
 #include "cutlass/arch/grid_dependency_control.h"

@@ -288,7 +287,7 @@ struct CollectiveMma<
 constexpr int tma_alignment_bits = 128;
 auto problem_shape_MNKL = append<4>(problem_shape, 1);
 auto [M,N,K,L] = problem_shape_MNKL;

 constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
 bool implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
 constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
@@ -445,7 +444,7 @@ struct CollectiveMma<
 copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
 ++k_tile_iter;

 if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
   launch_dep_grids = true;
   cutlass::arch::launch_dependent_grids();
 }
@@ -453,7 +452,7 @@ struct CollectiveMma<
 // Advance smem_pipe_write
 ++smem_pipe_write;
 }
 if (!disable_gdc && !launch_dep_grids) {
   cutlass::arch::launch_dependent_grids();
 }
 }
@@ -533,7 +532,7 @@ struct CollectiveMma<
 copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
 ++k_tile_iter;

 if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) {
   launch_dep_grids = true;
   cutlass::arch::launch_dependent_grids();
 }
@ -541,7 +540,7 @@ struct CollectiveMma<
|
|||||||
// Advance smem_pipe_write
|
// Advance smem_pipe_write
|
||||||
++smem_pipe_write;
|
++smem_pipe_write;
|
||||||
}
|
}
|
||||||
if (!disable_gdc && !launch_dep_grids) {
|
if (!disable_gdc && !launch_dep_grids) {
|
||||||
cutlass::arch::launch_dependent_grids();
|
cutlass::arch::launch_dependent_grids();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -634,9 +633,9 @@ struct CollectiveMma<
|
|||||||
// Issue the epilogue waits
|
// Issue the epilogue waits
|
||||||
if (lane_predicate) {
|
if (lane_predicate) {
|
||||||
/* This helps avoid early exit of blocks in Cluster
|
/* This helps avoid early exit of blocks in Cluster
|
||||||
* Waits for all stages to either be released (all
|
* Waits for all stages to either be released (all
|
||||||
* Consumer UNLOCKs), or if the stage was never used
|
* Consumer UNLOCKs), or if the stage was never used
|
||||||
* then would just be acquired since the phase was
|
* then would just be acquired since the phase was
|
||||||
* still inverted from make_producer_start_state
|
* still inverted from make_producer_start_state
|
||||||
*/
|
*/
|
||||||
pipeline.producer_tail(smem_pipe_write);
|
pipeline.producer_tail(smem_pipe_write);
|
||||||
@ -854,7 +853,7 @@ struct CollectiveMma<
|
|||||||
k_tile_count -= prologue_mma_count;
|
k_tile_count -= prologue_mma_count;
|
||||||
|
|
||||||
smem_pipe_release.advance(k_tile_count);
|
smem_pipe_release.advance(k_tile_count);
|
||||||
|
|
||||||
// Wait on all GMMAs to complete
|
// Wait on all GMMAs to complete
|
||||||
warpgroup_wait<0>();
|
warpgroup_wait<0>();
|
||||||
|
|
||||||
|
|||||||
@ -133,7 +133,7 @@ using TP = _8;

static constexpr int TP_ = TP{};

#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \
    (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
    (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))

// Distributed GEMM tiling/sharding schedule
// Choices:

@ -252,7 +252,8 @@ HostTensorB tensor_B_arr[TP_];

HostTensorD tensor_C_arr[TP_];
HostTensorD tensor_D_arr[TP_];

#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) &&
       // (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types

@ -345,7 +346,7 @@ struct Result {

};

#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \
    (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
    (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))

/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation

@ -803,17 +804,18 @@ int run(Options &options) {

  return 0;
}

#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) &&
       // (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))

///////////////////////////////////////////////////////////////////////////////////////////////////

int main(int argc, char const **args) {

  // CUTLASS must be compiled with CUDA Toolkit 12.4 or newer to run this example
  // CUTLASS must be compiled with CUDA Toolkit 12.6 or newer to run this example
  // and must have compute capability at least 90.
  // Some necessary cuda graph APIs were only introduced in CUDA 12.4.
  // Some necessary cuda graph APIs were only introduced in CUDA 12.6.
  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) {
  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 6)) {
    std::cerr << "This example requires CUDA 12.4 or newer." << std::endl;
    std::cerr << "This example requires CUDA 12.6 or newer." << std::endl;
    // Returning zero so this test passes on older Toolkits. Its actions are no-op.
    return 0;
  }

@ -857,11 +859,11 @@ int main(int argc, char const **args) {

  // Evaluate CUTLASS kernels
  //

#if (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
#if (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6)))
  run(options);
#else
  std::cerr
    << "This example must be compiled with `sm90a` and CUDA Toolkit 12.4 or later." << std::endl;
    << "This example must be compiled with `sm90a` and CUDA Toolkit 12.6 or later." << std::endl;
  return 0;
#endif
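The toolkit guard in this example now requires CUDA 12.6 or newer and consistently uses the major/minor disjunction shown above. The older conjunctive form `(__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4)` wrongly rejects any future major release with a small minor number (for example 13.0). A minimal sketch of the corrected predicate, using a plain helper function in place of the real `__CUDACC_VER_MAJOR__`/`__CUDACC_VER_MINOR__` macros:

```cpp
#include <cstdio>

// Returns true when (major, minor) is at least 12.6, mirroring the guard used in the example.
constexpr bool at_least_12_6(int major, int minor) {
  return major > 12 || (major == 12 && minor >= 6);
}

int main() {
  // 13.0 passes the disjunctive check but would fail the old (major >= 12 && minor >= 4) form.
  std::printf("12.4 -> %d, 12.6 -> %d, 13.0 -> %d\n",
              at_least_12_6(12, 4), at_least_12_6(12, 6), at_least_12_6(13, 0));
  return 0;
}
```

Compiled with any C++11 compiler, this prints `12.4 -> 0, 12.6 -> 1, 13.0 -> 1`, matching the intent of the new guard.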
@ -250,8 +250,6 @@ cutlass::DeviceAllocation<ElementAccumulator> block_beta;

/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////

using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams<Shape<int,int,int>>::RasterOrderOptions;

/// Result structure
struct Result
{

@ -518,7 +516,7 @@ GemmArguments args_from_options(const OptionType &options, bool host_problem_sha

    fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
  }

  arguments.scheduler.raster_order = options.raster;
  arguments.scheduler.raster_order = options.raster_order;
  // The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8)
  arguments.scheduler.max_swizzle_size = options.swizzle;

@ -690,10 +688,10 @@ int run(OptionType &options, bool host_problem_shapes_available = true)

  std::string raster = "Heuristic";

  if (options.raster == RasterOrderOptions::AlongN) {
  if (options.raster_order == RasterOrderOptions::AlongN) {
    raster = "Along N";
  }
  else if (options.raster == RasterOrderOptions::AlongM) {
  else if (options.raster_order == RasterOrderOptions::AlongM) {
    raster = "Along M";
  }

@ -747,7 +745,7 @@ int main(int argc, char const **args) {

  // Parse options
  //

  Options<RasterOrderOptions, ProblemShape> options;
  Options<ProblemShape> options;

  options.parse(argc, args);
@ -253,8 +253,6 @@ cutlass::DeviceAllocation<ElementAccumulator> block_beta;

/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////

using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams<Shape<int,int,int>>::RasterOrderOptions;

/// Result structure
struct Result
{

@ -523,7 +521,7 @@ GemmArguments args_from_options(const OptionType &options, bool host_problem_sha

    fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1};
  }

  arguments.scheduler.raster_order = options.raster;
  arguments.scheduler.raster_order = options.raster_order;
  // The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8)
  arguments.scheduler.max_swizzle_size = options.swizzle;

@ -699,10 +697,10 @@ int run(OptionType &options, bool host_problem_shapes_available = true)

  std::string raster = "Heuristic";

  if (options.raster == RasterOrderOptions::AlongN) {
  if (options.raster_order == RasterOrderOptions::AlongN) {
    raster = "Along N";
  }
  else if (options.raster == RasterOrderOptions::AlongM) {
  else if (options.raster_order == RasterOrderOptions::AlongM) {
    raster = "Along M";
  }

@ -755,7 +753,7 @@ int main(int argc, char const **args) {

  // Parse options
  //

  Options<RasterOrderOptions, ProblemShape> options;
  Options<ProblemShape> options;

  options.parse(argc, args);
@ -30,10 +30,9 @@

**************************************************************************************************/

// Command line options parsing
template<typename _RasterOrderOptions, typename _ProblemShape>
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;

template<typename _ProblemShape>
struct Options {

  using RasterOrderOptions = _RasterOrderOptions;
  using ProblemShape = _ProblemShape;

  bool help = false;

@ -50,7 +49,7 @@ struct Options {

  int const m_alignment = 128;
  int const n_alignment = 128;

  RasterOrderOptions raster;
  RasterOrderOptions raster_order;
  int swizzle;

  // Parses the command line

@ -74,13 +73,13 @@ struct Options {

    cmd.get_cmd_line_argument("raster", raster_char);

    if (raster_char == 'N' || raster_char == 'n') {
      raster = RasterOrderOptions::AlongN;
      raster_order = RasterOrderOptions::AlongN;
    }
    else if (raster_char == 'M' || raster_char == 'm') {
      raster = RasterOrderOptions::AlongM;
      raster_order = RasterOrderOptions::AlongM;
    }
    else if (raster_char == 'H' || raster_char == 'h') {
      raster = RasterOrderOptions::Heuristic;
      raster_order = RasterOrderOptions::Heuristic;
    }

    cmd.get_cmd_line_argument("swizzle", swizzle, 1);
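The grouped GEMM examples above now take `RasterOrderOptions` directly from `cutlass::gemm::kernel::detail` and keep the parsed value in `Options::raster_order`. A minimal sketch of the `--raster` selection logic, using a local stand-in enum so the snippet stays self-contained rather than the actual CUTLASS type:

```cpp
#include <iostream>

// Local stand-in for cutlass::gemm::kernel::detail::RasterOrderOptions, for illustration only.
enum class RasterOrderOptions { Heuristic, AlongN, AlongM };

// Mirrors the Options::parse() branches in the diff: a single character from --raster
// selects the tile-scheduler raster order, defaulting to the heuristic.
RasterOrderOptions parse_raster(char raster_char) {
  if (raster_char == 'N' || raster_char == 'n') return RasterOrderOptions::AlongN;
  if (raster_char == 'M' || raster_char == 'm') return RasterOrderOptions::AlongM;
  return RasterOrderOptions::Heuristic;
}

int main() {
  RasterOrderOptions raster_order = parse_raster('m');
  std::cout << "raster_order == AlongM: "
            << (raster_order == RasterOrderOptions::AlongM) << "\n";
  return 0;
}
```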
@ -543,7 +543,7 @@ int run(Options &options) {

int main(int argc, char const **args) {

  // CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
  // CUTLASS must be compiled with CUDA 12.8 Toolkit or newer to run this example
  // and must have compute capability at least 100.
  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
    std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;

@ -560,7 +560,6 @@ int main(int argc, char const **args) {

    std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
    return 0;
  }

  //
  // Parse options
  //
@ -237,7 +237,7 @@ cutlass::DeviceAllocation<ElementAccumulator> block_beta;

/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////

using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
// Command line options parsing
struct Options {
@ -300,7 +300,7 @@ auto make_iterator(T* ptr) {

/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////

using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
// Command line options parsing
struct Options {
@ -490,7 +490,7 @@ int run(Options &options)

int main(int argc, char const **args) {

  // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
  // CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
  // and must have compute capability at least 90.
  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
    std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;

@ -503,11 +503,11 @@ int main(int argc, char const **args) {

  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  cudaError_t error = cudaGetDeviceProperties(&props, 0);

  if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
    std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
    return 0;
  }

  //
  // Parse options
  //
@ -490,7 +490,7 @@ int run(Options &options)

int main(int argc, char const **args) {

  // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
  // CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example
  // and must have compute capability at least 90.
  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
    std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;

@ -503,11 +503,11 @@ int main(int argc, char const **args) {

  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  cudaError_t error = cudaGetDeviceProperties(&props, 0);

  if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
    std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
    return 0;
  }

  //
  // Parse options
  //
@ -499,11 +499,11 @@ int main(int argc, char const **args) {

  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  cudaError_t error = cudaGetDeviceProperties(&props, 0);

  if (props.major != 10 && (props.minor != 0 || props.minor != 1)) {
    std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl;
    return 0;
  }

  //
  // Parse options
  //
@ -117,15 +117,17 @@ struct Options {

  int q = 256;
  int k = 256;
  int d = 128;
  int warmup_iterations = 1;
  int iterations = 3;
  int tensor_ring_buffers = 1;
  bool verify = false;
  bool verbose = false;

  bool causal = false;
  bool residual = false;
  bool varlen = false;
  bool persistent = false;
  int sm_count = 0;

  std::string kernel_filter;

  InitStyle init_style_q = InitStyle::kRandom;

@ -189,10 +191,15 @@ struct Options {

    if (b == -1) b = 16384 / k;
    if (b == 0) b = 1;

    cmd.get_cmd_line_argument("warmup_iterations", warmup_iterations, defaults.warmup_iterations);
    cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
    cmd.get_cmd_line_argument("tensor_ring_buffers", tensor_ring_buffers, defaults.tensor_ring_buffers);

    verify = cmd.check_cmd_line_flag("verify");
    verbose = cmd.check_cmd_line_flag("verbose");
    varlen = cmd.check_cmd_line_flag("varlen");
    persistent = cmd.check_cmd_line_flag("persistent");

    std::string mask;
    cmd.get_cmd_line_argument<std::string>("mask", mask, "");
    if (mask == "no" || mask == "") {

@ -210,7 +217,6 @@ struct Options {

      causal = false;
    }
    cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count);

    get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q);
    get_init_style_argument(cmd, "init-style", init_style_k, defaults.init_style_q);
    get_init_style_argument(cmd, "init-style", init_style_v, defaults.init_style_q);

@ -235,10 +241,13 @@ struct Options {

      << " --q=<int> Sets the Q extent\n"
      << " --k=<int> Sets the K extent\n"
      << " --d=<int> Sets the D extent\n"
      << " --tensor_ring_buffers=<int> Sets the number of tensor ring buffers\n"
      << " --warmup_iterations=<int> Sets the warmup iterations\n"
      << " --iterations=<int> Benchmarking iterations\n"
      << " --verify Verify results\n"
      << " --verbose Print smem and execution time per kernel\n"
      << " --mask=<no|residual|causal> Enables masking\n"
      << " --persistent Enables persistent scheduler\n"
      << " --varlen Enables variable sequence length\n"
      << " B*Q and B*K become the total sequence length\n"
      << " and are split B-ways, alternatingly +10% and -10%\n"
@ -379,40 +388,55 @@ struct FwdRunner {

  StrideLSE stride_LSE;
  uint64_t seed = 0;

  DeviceAllocation<Element> block_Q;
  DeviceAllocation<Element> block_K;
  DeviceAllocation<Element> block_V;
  DeviceAllocation<ElementOut> block_O;
  DeviceAllocation<ElementAccumulatorPV> block_LSE;
  DeviceAllocation<ElementOut> block_ref_O;
  DeviceAllocation<ElementAccumulatorPV> block_ref_LSE;
  struct DeviceBuffer {
    DeviceAllocation<Element> block_Q;
    DeviceAllocation<Element> block_K;
    DeviceAllocation<Element> block_V;
    DeviceAllocation<ElementOut> block_O;
    DeviceAllocation<ElementAccumulatorPV> block_LSE;
    DeviceAllocation<ElementOut> block_ref_O;
    DeviceAllocation<ElementAccumulatorPV> block_ref_LSE;
    DeviceAllocation<int> device_cumulative_seqlen_q;
    DeviceAllocation<int> device_cumulative_seqlen_kv;

    DeviceBuffer() = default;
    DeviceBuffer(const DeviceBuffer&) = delete;
    DeviceBuffer& operator=(const DeviceBuffer&) = delete;

    size_t get_storage_size() const {
      return block_Q.get_storage_size() + block_K.get_storage_size() + block_V.get_storage_size()
        + block_O.get_storage_size() + block_LSE.get_storage_size() + block_ref_O.get_storage_size()
        + block_ref_LSE.get_storage_size() + device_cumulative_seqlen_q.get_storage_size()
        + device_cumulative_seqlen_kv.get_storage_size();
    }
  };

  std::vector<std::unique_ptr<DeviceBuffer>> buffers;

  std::vector<int> cumulative_seqlen_q;
  std::vector<int> cumulative_seqlen_kv;
  DeviceAllocation<int> device_cumulative_seqlen_q;
  DeviceAllocation<int> device_cumulative_seqlen_kv;

  //
  // Methods
  //
  bool verify(const ProblemShapeType& problem_shape) {
  bool verify(const ProblemShapeType& problem_shape, DeviceBuffer& buffer) {
    Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()),
    Tensor mQ = make_tensor(make_gmem_ptr(buffer.block_Q.get()),
      select<0,2,3>(problem_shape),
      stride_Q);

    Tensor mK = make_tensor(make_gmem_ptr(block_K.get()),
    Tensor mK = make_tensor(make_gmem_ptr(buffer.block_K.get()),
      select<1,2,3>(problem_shape),
      stride_K);

    Tensor mV = make_tensor(make_gmem_ptr(block_V.get()),
    Tensor mV = make_tensor(make_gmem_ptr(buffer.block_V.get()),
      select<1,2,3>(problem_shape),
      stride_V);

    Tensor mO = make_tensor(make_gmem_ptr(block_ref_O.get()),
    Tensor mO = make_tensor(make_gmem_ptr(buffer.block_ref_O.get()),
      select<0,2,3>(problem_shape),
      stride_O);

    Tensor mLSE = make_tensor(make_gmem_ptr(block_ref_LSE.get()),
    Tensor mLSE = make_tensor(make_gmem_ptr(buffer.block_ref_LSE.get()),
      select<0,3>(problem_shape),
      stride_LSE);

@ -431,7 +455,7 @@ struct FwdRunner {

    // Check if output from CUTLASS kernel and reference kernel are equal or not
    double max_diff = 0;
    double mean_diff = 0;
    reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff);
    reference_abs_diff(buffer.block_O, buffer.block_ref_O, max_diff, mean_diff);

    bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
    if (! passed_O) {

@ -439,14 +463,13 @@ struct FwdRunner {

                << " mean " << mean_diff << std::endl;
    }

    // reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff);

    bool passed_LSE = true; // future work
    // bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
    // if ( ! passed_LSE) {
    //   std::cerr << "failed LSE: max diff " << max_diff
    //             << " mean " << mean_diff << std::endl;
    // }
    reference_abs_diff(buffer.block_LSE, buffer.block_ref_LSE, max_diff, mean_diff);

    bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh);
    if ( ! passed_LSE) {
      std::cerr << "failed LSE: max diff " << max_diff
                << " mean " << mean_diff << std::endl;
    }

    return passed_O && passed_LSE;
  }
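verify() above now applies the same max/mean absolute-difference acceptance test to the LSE output as to O. A minimal host-side sketch of that two-threshold check, with illustrative threshold values standing in for the example's `kMaxDiffThresh`/`kMeanDiffThresh`:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Computes max and mean absolute difference between a result and its reference,
// then applies the two-threshold acceptance test used by the runner's verify().
bool passes(const std::vector<float>& out, const std::vector<float>& ref,
            double max_thresh, double mean_thresh) {
  double max_diff = 0.0, sum_diff = 0.0;
  for (size_t i = 0; i < out.size(); ++i) {
    double d = std::fabs(double(out[i]) - double(ref[i]));
    max_diff = std::max(max_diff, d);
    sum_diff += d;
  }
  double mean_diff = out.empty() ? 0.0 : sum_diff / double(out.size());
  return (max_diff < max_thresh) && (mean_diff < mean_thresh);
}

int main() {
  std::vector<float> out = {1.0f, 2.0f, 3.0f}, ref = {1.0f, 2.001f, 2.999f};
  std::printf("passed: %d\n", passes(out, ref, /*max_thresh=*/1e-2, /*mean_thresh=*/1e-3));
  return 0;
}
```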
@ -559,50 +582,70 @@ struct FwdRunner {

      get<1,1>(stride_LSE) = 0;
    }

    block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
    block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
    block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
    block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
    block_LSE.reset(size(shape_LSE));
    block_ref_O.reset(size(shape_QO));
    block_ref_LSE.reset(size(shape_LSE));

    initialize_block(block_Q, seed + 2023, options.init_style_q);
    initialize_block(block_K, seed + 2022, options.init_style_k);
    initialize_block(block_V, seed + 2021, options.init_style_v);

    if ( ! cumulative_seqlen_q.empty()) {
      device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
      device_cumulative_seqlen_q.copy_from_host(
        cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
    }
    if ( ! cumulative_seqlen_kv.empty()) {
      device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
      device_cumulative_seqlen_kv.copy_from_host(
        cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
    }

    auto buffer_init_fn = [&](auto& buffer) {
      buffer.block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
      buffer.block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
      buffer.block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0);
      buffer.block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0);
      buffer.block_LSE.reset(size(shape_LSE));

      initialize_block(buffer.block_Q, seed + 2023, options.init_style_q);
      initialize_block(buffer.block_K, seed + 2022, options.init_style_k);
      initialize_block(buffer.block_V, seed + 2021, options.init_style_v);

      if ( ! cumulative_seqlen_q.empty()) {
        buffer.device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
        buffer.device_cumulative_seqlen_q.copy_from_host(
          cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
      }
      if ( ! cumulative_seqlen_kv.empty()) {
        buffer.device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
        buffer.device_cumulative_seqlen_kv.copy_from_host(
          cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
      }
    };

    buffers.push_back(std::make_unique<DeviceBuffer>());
    buffer_init_fn(*buffers.back());

    int tensor_ring_buffers = options.tensor_ring_buffers;
    for (int i = 1; i < tensor_ring_buffers; i++) {
      buffers.push_back(std::make_unique<DeviceBuffer>());
      buffer_init_fn(*buffers.back());
    }

    if constexpr (kIsVarlen) {
      get<0>(problem_shape).cumulative_length = device_cumulative_seqlen_q.get();
      get<0>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_q.get();
      get<1>(problem_shape).cumulative_length = device_cumulative_seqlen_kv.get();
      get<1>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_kv.get();
    }

    return problem_shape;
  }

  auto get_arguments(const ProblemShapeType& problem_shape, const cutlass::KernelHardwareInfo& hw_info, int buffer_index) {
    auto problem_shape_ = problem_shape;
    if constexpr (kIsVarlen) {
      get<0>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_q.get();
      get<1>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_kv.get();
    }
    typename Operation::Arguments arguments{
      problem_shape_,
      { buffers[buffer_index]->block_Q.get(), stride_Q,
        buffers[buffer_index]->block_K.get(), stride_K,
        buffers[buffer_index]->block_V.get(), stride_V },
      { buffers[buffer_index]->block_O.get(), stride_O,
        buffers[buffer_index]->block_LSE.get(), stride_LSE },
      hw_info
    };
    return arguments;
  }

  ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {

    ProblemShapeType problem_shape = initialize(options);

    typename Operation::Arguments arguments{
      problem_shape,
      { block_Q.get(), stride_Q,
        block_K.get(), stride_K,
        block_V.get(), stride_V },
      { block_O.get(), stride_O,
        block_LSE.get(), stride_LSE },
      hw_info
    };
    int buffer_index = 0;
    typename Operation::Arguments arguments = get_arguments(problem_shape, hw_info, buffer_index);

    Operation op;
@ -630,11 +673,21 @@ struct FwdRunner {

    }

    // Run
    status = op.run();
    if (status != cutlass::Status::kSuccess) {
      std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
                << cudaGetErrorString(cudaGetLastError()) << std::endl;
      return example_result;
    for (int i = 0; i < options.warmup_iterations; i++) {
      status = op.run();
      if (status != cutlass::Status::kSuccess) {
        std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: "
                  << cudaGetErrorString(cudaGetLastError()) << std::endl;
        return example_result;
      }
      buffer_index = (buffer_index + 1) % buffers.size();
      arguments = get_arguments(problem_shape, hw_info, buffer_index);
      status = op.update(arguments, workspace.get());
      if (status != cutlass::Status::kSuccess) {
        std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: "
                  << std::endl;
        return example_result;
      }
    }

    cudaError_t result = cudaDeviceSynchronize();

@ -672,6 +725,14 @@ struct FwdRunner {

                  << cudaGetErrorString(cudaGetLastError()) << std::endl;
        return example_result;
      }
      buffer_index = (buffer_index + 1) % buffers.size();
      arguments = get_arguments(problem_shape, hw_info, buffer_index);
      status = op.update(arguments, workspace.get());
      if (status != cutlass::Status::kSuccess) {
        std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: "
                  << std::endl;
        return example_result;
      }
    }

    //

@ -734,10 +795,10 @@ struct FwdRunner {

    // Verify that the result is correct
    bool passed = true;
    if (options.verify) {
      passed = verify(problem_shape);
      passed = verify(problem_shape, *buffers[0]);
      if (passed) example_result.verified = true;
    }

    if (!passed) {
      std::cerr << "Reference check failed" << std::endl;
      return example_result;
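The warmup and benchmarking loops above rotate through the `tensor_ring_buffers` copies of the Q/K/V/O tensors between launches, refreshing the kernel arguments via `op.update()` each time, presumably so repeated iterations do not keep replaying the exact same cache-resident inputs. A minimal sketch of the rotation pattern, with a hypothetical `launch` callback standing in for the `op.run()`/`op.update()` pair:

```cpp
#include <cstdio>
#include <functional>

// Rotates through `num_buffers` input buffers across `iterations` launches,
// mirroring the buffer_index = (buffer_index + 1) % buffers.size() pattern above.
void run_with_ring_buffers(int num_buffers, int iterations,
                           const std::function<void(int)>& launch) {
  int buffer_index = 0;
  for (int i = 0; i < iterations; i++) {
    launch(buffer_index);                             // e.g. op.run() on this buffer's arguments
    buffer_index = (buffer_index + 1) % num_buffers;  // then op.update() with the next buffer
  }
}

int main() {
  run_with_ring_buffers(3, 5, [](int idx) { std::printf("launch with buffer %d\n", idx); });
  return 0;
}
```

With three buffers and five iterations the launches cycle 0, 1, 2, 0, 1, which is the behavior the new `--tensor_ring_buffers` option exposes.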
@ -789,10 +850,14 @@ void run_fwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareIn

  using HeadDim = _128;

  // Persistent Tile Scheduler
  run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
  // Individual Tile Scheduler
  run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
  if (options.persistent) {
    // Persistent Tile Scheduler
    run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
  }
  else {
    // Individual Tile Scheduler
    run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
  }
}

///////////////////////////////////////////////////////////////////////////////////////////////////

@ -818,10 +883,14 @@ void run_fwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf

  using HeadDim = _64;

  // Persistent Tile Scheduler
  run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
  // Individual Tile Scheduler
  run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
  if (options.persistent) {
    // Persistent Tile Scheduler
    run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
  }
  else {
    // Individual Tile Scheduler
    run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
  }
}

@ -845,10 +914,14 @@ void run_fwd_32(Mask fusion, Options const & options, cutlass::KernelHardwareInf

  using HeadDim = _32;

#ifdef FP8
  // Persistent Tile Scheduler
  run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
  // Individual Tile Scheduler
  run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
  if (options.persistent) {
    // Persistent Tile Scheduler
    run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option<Tag::kIsPersistent, true_type>{});
  }
  else {
    // Individual Tile Scheduler
    run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option<Tag::kIsPersistent, false_type>{});
  }
#endif
}
@ -59,7 +59,7 @@ using namespace cutlass::fmha::kernel;

///////////////////////////////////////////////////////////////////////////////////////////////////

enum class InitStyle {
  kOne, kLinearStride128, kLinearStride1, kRandom, kNone
  kOne, kLinearStride128, kLinearStride1, kRandom, kRandomLarge, kNone
};

///////////////////////////////////////////////////////////////////////////////////////////////////

@ -98,6 +98,9 @@ struct Options {

    if (s == "r") {
      dst = InitStyle::kRandom;
    }
    else if (s == "l") {
      dst = InitStyle::kRandomLarge;
    }
    else if (s == "1") {
      dst = InitStyle::kOne;
    }

@ -203,6 +206,11 @@ void initialize_block(

        block.get(), block.size(), seed, (Element) -1, (Element) 1);
      break;
    }
    case InitStyle::kRandomLarge: {
      cutlass::reference::device::BlockFillRandomGaussian(
        block.get(), block.size(), seed, (Element) -1, (Element) 100);
      break;
    }
    case InitStyle::kLinearStride1: {
      std::vector<Element> data(block.size());
      for (size_t i = 0; i < block.size() / 128; i ++) {
@ -144,4 +144,23 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC

    target_compile_definitions(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${PREC_MACRO} SKIP_ATOMIC)
    target_compile_options(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE -Xptxas -v)
  endforeach()

  # Add a target that builds all examples
  add_custom_target(77_blackwell_fmha_all
    DEPENDS
    77_blackwell_fmha_fp8
    77_blackwell_fmha_fp16
    77_blackwell_fmha_gen_fp8
    77_blackwell_fmha_gen_fp16
    77_blackwell_mla_2sm_fp8
    77_blackwell_mla_2sm_fp16
    77_blackwell_mla_2sm_cpasync_fp8
    77_blackwell_mla_2sm_cpasync_fp16
    77_blackwell_mla_b2b_2sm_fp8
    77_blackwell_mla_b2b_2sm_fp16
    77_blackwell_fmha_bwd_fp8
    77_blackwell_fmha_bwd_fp16
    77_blackwell_fmha_bwd_sat_fp8
    77_blackwell_fmha_bwd_sat_fp16
  )
endif()
@ -157,8 +157,8 @@ struct CausalMask : NoMask {

    TileShape const& tile_shape,
    ProblemSize const& problem_size) {

    int trip_count = get_trip_count(blk_coord, tile_shape, problem_size);
    return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
  }

  template<class BlkCoord, class TileShape, class ProblemSize>
@ -42,7 +42,7 @@ template<

  class ElementAcc,
  class TileShape,   // Q, D, _
  class StrideO,     // Q, D, B
  class StrideLSE    // Q, B
  class StrideLSE_   // Q, B
>
struct Sm100FmhaFwdEpilogueTmaWarpspecialized {

@ -54,6 +54,7 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {

  // using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{}));
  using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{}));
  using SmemLayoutO_ = SmemLayoutO;
  using StrideLSE = StrideLSE_;

  struct TensorStorage {

@ -79,6 +80,9 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {

  struct Params {
    TMA_O tma_store_o;

    ElementAcc* ptr_LSE;
    StrideLSE dLSE;
  };

  template<class ProblemShape>

@ -110,7 +114,9 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {

    );

    return {
      tma_store_o
      tma_store_o,
      args.ptr_LSE,
      args.dLSE
    };
  }

@ -119,6 +125,10 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {

    cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor());
  }

  const Params& params;

  CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& params) : params(params) {}

  template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
  CUTLASS_DEVICE auto
  store(
@ -531,7 +531,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
|||||||
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
|
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
|
||||||
|
|
||||||
// Each thread owns a single row
|
// Each thread owns a single row
|
||||||
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
|
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
|
||||||
using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem
|
using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem
|
||||||
using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
|
using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
|
||||||
|
|
||||||
@ -613,7 +613,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
|||||||
NumericArrayConverter<Element, ElementQK, kConversionsPerStep> convert;
|
NumericArrayConverter<Element, ElementQK, kConversionsPerStep> convert;
|
||||||
|
|
||||||
const int kReleasePipeCount = 10; // must be multiple of 2
|
const int kReleasePipeCount = 10; // must be multiple of 2
|
||||||
|
|
||||||
order_s.wait();
|
order_s.wait();
|
||||||
|
|
||||||
CUTLASS_PRAGMA_UNROLL
|
CUTLASS_PRAGMA_UNROLL
|
||||||
@ -637,7 +637,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
|||||||
}
|
}
|
||||||
tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv);
|
tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv);
|
||||||
|
|
||||||
|
|
||||||
if (i == size(tTMEM_LOADrS) - kReleasePipeCount) {
|
if (i == size(tTMEM_LOADrS) - kReleasePipeCount) {
|
||||||
order_s.arrive();
|
order_s.arrive();
|
||||||
}
|
}
|
||||||
@ -691,7 +691,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
|||||||
cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3);
|
cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3);
|
||||||
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2);
|
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2);
|
||||||
float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y;
|
float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y;
|
||||||
|
|
||||||
row_sum = local_row_sum;
|
row_sum = local_row_sum;
|
||||||
|
|
||||||
if (final_call) {
|
if (final_call) {
|
||||||
@ -787,14 +787,14 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
|||||||
// good values would be either 32 or 64
|
// good values would be either 32 or 64
|
||||||
const int kCorrectionTileSize = 32 / sizeof(ElementOut);
|
const int kCorrectionTileSize = 32 / sizeof(ElementOut);
|
||||||
|
|
||||||
using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem
|
using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem
|
||||||
|
|
||||||
typename CollectiveMmaPV::TiledMma mma;
|
typename CollectiveMmaPV::TiledMma mma;
|
||||||
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
|
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
|
||||||
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
|
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
|
||||||
Tensor tOcO = mma.get_slice(0).partition_C(cO);
|
Tensor tOcO = mma.get_slice(0).partition_C(cO);
|
||||||
Tensor tOsO = mma.get_slice(0).partition_C(sO);
|
Tensor tOsO = mma.get_slice(0).partition_C(sO);
|
||||||
|
|
||||||
Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
||||||
Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
||||||
Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
||||||
@ -809,7 +809,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
|||||||
|
|
||||||
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{}));
|
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{}));
|
||||||
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
|
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
|
||||||
|
|
||||||
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _));
|
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _));
|
||||||
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _));
|
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _));
|
||||||
Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _));
|
Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _));
|
||||||
@ -824,9 +824,9 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
|||||||
Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i);
|
Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i);
|
||||||
|
|
||||||
Tensor tTMrO = make_tensor<ElementPV>(shape(tTMEM_LOADcO(_, _0{}, _0{}, i)));
|
Tensor tTMrO = make_tensor<ElementPV>(shape(tTMEM_LOADcO(_, _0{}, _0{}, i)));
|
||||||
|
|
||||||
copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO);
|
copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO);
|
||||||
|
|
||||||
#ifndef ONLY_SOFTMAX
|
#ifndef ONLY_SOFTMAX
|
||||||
CUTLASS_PRAGMA_UNROLL
|
CUTLASS_PRAGMA_UNROLL
|
||||||
for (int j = 0; j < size(tTMrO); j += 2) {
|
for (int j = 0; j < size(tTMrO); j += 2) {
|
||||||
@ -872,24 +872,24 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
|||||||
// good values would be either 32 or 64
|
// good values would be either 32 or 64
|
||||||
const int kCorrectionTileSize = 16;
|
const int kCorrectionTileSize = 16;
|
||||||
|
|
||||||
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
|
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
|
||||||
using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
|
using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem
|
||||||
|
|
||||||
typename CollectiveMmaPV::TiledMma mma;
|
typename CollectiveMmaPV::TiledMma mma;
|
||||||
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
|
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
|
||||||
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
|
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
|
||||||
Tensor tOcO = mma.get_slice(0).partition_C(cO);
|
Tensor tOcO = mma.get_slice(0).partition_C(cO);
|
||||||
|
|
||||||
Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
||||||
Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
|
||||||
|
|
||||||
tOtO_i.data() = tOtO_i.data().get() + tmem_O;
|
tOtO_i.data() = tOtO_i.data().get() + tmem_O;
|
||||||
|
|
||||||
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i);
|
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i);
|
||||||
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
|
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
|
||||||
auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i);
|
auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i);
|
||||||
auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx);
|
auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx);
|
||||||
|
|
||||||
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i);
|
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i);
|
||||||
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i);
|
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i);
|
||||||
Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i);
|
Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i);
|
||||||
@ -899,7 +899,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
|
|||||||
float2 scale_f32x2 = make_float2(scale, scale);
|
float2 scale_f32x2 = make_float2(scale, scale);
|
||||||
|
|
||||||
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));
|
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));
|
||||||
|
|
||||||
auto copy_in = [&](int i) {
|
auto copy_in = [&](int i) {
|
||||||
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;
|
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;
|
||||||
tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize);
|
tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize);
|
||||||
@ -942,7 +942,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
}
}

template<class BlkCoord, class ProblemShape, class TensorStorageEpi>
template<class BlkCoord, class ProblemShape, class TensorStorageEpi, class CollectiveEpilogue>
CUTLASS_DEVICE auto
correction(
BlkCoord const& blk_coord,
@ -951,7 +951,8 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state,
PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state,
PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state,
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state) {
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
CollectiveEpilogue& epilogue) {

int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);

@ -961,7 +962,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {

Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{}));
Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);

Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));

@ -1060,13 +1061,25 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
// F2FP
// store to smem
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), repeat_like(typename CollectiveEpilogue::StrideLSE{}, _1{}), epilogue.params.dLSE);

correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO);

if (epilogue.params.ptr_LSE != nullptr) {
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord);

ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);

if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx, get<2>(blk_coord)) = lse;
}
}

cutlass::arch::fence_view_async_tmem_load();

pipeline_o.consumer_release(pipeline_o_consumer_state);
++pipeline_o_consumer_state;

pipeline_epi.producer_commit(pipeline_epi_producer_state);
++pipeline_epi_producer_state;

@ -1083,6 +1096,16 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {

correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO);

if (epilogue.params.ptr_LSE != nullptr) {
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});

ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);

if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx, get<2>(blk_coord)) = lse;
}
}

cutlass::arch::fence_view_async_tmem_load();

pipeline_o.consumer_release(pipeline_o_consumer_state);
@ -118,7 +118,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {

using TensorStrideContiguousK = Stride<int, _1, Stride<int, int>>;
using TensorStrideContiguousMN = Stride<_1, int, Stride<int, int>>;

// compute S
using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
@ -381,7 +381,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q, D, HB), args.mainloop.stride_dq_acc),
SmemLayoutDQ{}(_, _, _0{})
);

return Params{
args.problem_shape,
args.mainloop,
@ -452,7 +452,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {

ThrMMA cta_mma_kq = TiledMmaKQ{}.get_slice(_0{});
ThrMMA cta_mma_vdo = TiledMmaVDO{}.get_slice(_0{});

auto tSTgK = cta_mma_kq.partition_A(gK);
auto tSTgQ = cta_mma_kq.partition_B(gQ);
auto tDPTgV = cta_mma_vdo.partition_A(gV);
@ -477,7 +477,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO));

// set up lse and sum_odo

auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord;

pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state);
@ -495,7 +495,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}

// load Q
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
@ -520,7 +520,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
&mLSE(gmem_idx, blk_coord_batch),
gmem_idx < Q
);

pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_lse_producer_state;

@ -529,7 +529,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);

pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV);

// load V
if (cute::elect_one_sync()) {
cute::copy(
@ -540,7 +540,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}

// load dO
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
@ -573,7 +573,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state);

// load Q
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask),
tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch),
@ -584,7 +584,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
++pipeline_load_mma_q_producer_state;

pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state);

// load LSE
smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
@ -593,15 +593,15 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
&mLSE(gmem_idx, blk_coord_batch),
gmem_idx < Q
);

pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_lse_producer_state;

pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state);
tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state);

// load dO
if (cute::elect_one_sync()) {
cute::copy(
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask),
tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch),
@ -612,7 +612,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
++pipeline_load_mma_do_producer_state;

pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state);

// load sum_OdO
smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4;
gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4;
@ -621,7 +621,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
&mSumOdO(gmem_idx, blk_coord_batch),
gmem_idx < Q
);

pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive);
++pipeline_load_compute_sum_odo_producer_state;

@ -639,23 +639,23 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
int iter_count,
MainloopArguments const& mainloop_args,
TensorStorage& shared_tensors,
PipelineLoadMmaQ& pipeline_load_mma_q,
typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state,
PipelineLoadMmaDO& pipeline_load_mma_do,
typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state,
PipelineMmaComputeS& pipeline_mma_compute_s,
typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state,
PipelineMmaComputeDP& pipeline_mma_compute_dp,
typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state,
PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state,
PipelineComputeMmaP& pipeline_compute_mma_p,
typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state,
PipelineComputeMmaDS& pipeline_compute_mma_ds,
typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state,
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) {

auto [Q, K, D, HB] = problem_shape;

auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
@ -685,7 +685,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{});
tDVrP.data() = TmemAllocation::kP;
Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT);

TiledMmaKQ tiled_mma_kq;
TiledMmaVDO tiled_mma_vdo;
TiledMmaDSK tiled_mma_dsk;
@ -923,6 +923,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
TensorC const& coord,
TensorShape const& tensor_shape) {

Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });

auto copy_op = make_cotiled_copy(
Copy_Atom<UniversalCopy<uint128_t>, Element>{},
make_layout(make_shape(_1{}, Int<sizeof(uint128_t) / sizeof(Element)>{})),
@ -930,21 +932,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
);
auto thr_copy = copy_op.get_slice(_0{});

auto tCg = thr_copy.partition_D(gmem);
Tensor tCg = thr_copy.partition_D(gmem);
auto tCr = thr_copy.partition_S(quantize(regs));
Tensor tCr = thr_copy.partition_S(quantize(regs));
auto tCc = thr_copy.partition_D(coord);
Tensor tPc = thr_copy.partition_D(preds);

constexpr int R = decltype(tCr.layout())::rank;
copy_if(copy_op, tPc, tCr, tCg);
auto tCg_v = group_modes<1, R>(tCg);
auto tCr_v = group_modes<1, R>(tCr);
auto tCc_v = group_modes<1, R>(tCc);
auto tCp_v = make_tensor<bool>(shape<1>(tCc_v));

for (int i = 0; i < size(tCp_v); ++i) {
tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape);
}

copy_if(copy_op, tCp_v, tCr_v, tCg_v);
}

@ -1073,7 +1065,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv,
typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) {


auto [Q, K, D, HB] = problem_shape;

// in tmem, S & P overlap
@ -1114,7 +1106,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
Tensor tTR_cST = split_wg(thread_t2r.partition_D(cST));
Tensor tTR_rST = make_tensor<ElementAcc>(shape(tTR_cST));
Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST));

Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT);
Tensor tTR_cDPT = split_wg(tTR_cDPT_p);
Tensor tTR_rDPT = make_tensor<ElementAcc>(shape(tTR_cDPT));
@ -1152,20 +1144,20 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
fn(cute::false_type{});
}
};

dispatch_bool(std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask> &&
warp_uniform(iter_index == get<1>(blk_coord)), [&](auto is_causal_masked_tile) {

// compute P = softmax(S, LSE)
cute::copy(tiled_t2r, tTR_tST, tTR_rST);

if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask, Mask> && decltype(is_causal_masked_tile)::value) {
Mask{}.apply_mask(tTR_rST, [&](int i) {
auto c_transpose = tTR_cST(i);
return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{});
}, problem_shape);
}

ElementAcc log2_e = static_cast<ElementAcc>(M_LOG2E);
float2 softmax_scale_log2_e;
softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e;
@ -1184,16 +1176,16 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tTR_rST(i) = ::exp2f(out.x);
tTR_rST(i+1) = ::exp2f(out.y);
}

auto tRT_rST = quantize(tTR_rST);
auto tRT_rST_reshaped = make_tensor(tRT_rST.data(), shape(tRT_cST));

cutlass::arch::fence_view_async_tmem_load();
cutlass::arch::NamedBarrier(
kNumComputeWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::TransformBarrier
).arrive_and_wait();

cute::copy(tiled_r2t, tRT_rST_reshaped, tRT_tP);
});

@ -1293,9 +1285,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state,
PipelineReduceTmaStore& pipeline_reduce_tma_store,
typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) {

using X = Underscore;

auto [Q, K, D, HB] = problem_shape;

auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord;
@ -1307,7 +1299,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
tDQtDQ.data() = TmemAllocation::kDQ;

Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB));
auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<_1, _1, X>{})
auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<X, _1, _1>{})
(_, _, _, _0{}, blk_coord_batch);

Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{}));
@ -1376,7 +1368,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
iter_index += 1;
}
}


CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
int warp_idx = cutlass::canonical_warp_idx_sync();
@ -1561,7 +1553,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state;
typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state;
typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state;

auto pipeline_load_mma_q_producer_state = make_producer_start_state<decltype(pipeline_load_mma_q)>();
auto pipeline_load_mma_do_producer_state = make_producer_start_state<decltype(pipeline_load_mma_do)>();
auto pipeline_load_compute_lse_producer_state = make_producer_start_state<decltype(pipeline_load_compute_lse)>();
@ -1587,7 +1579,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {

if (role == WarpRole::Load) {
warpgroup_reg_set<RegisterAllocation::kLoad>();

load(
blk_coord,
problem_shape,
@ -1596,7 +1588,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
params.mainloop,
params.mainloop_params,
shared_storage.tensors,
pipeline_load_mma_q, pipeline_load_mma_q_producer_state,
pipeline_load_mma_do, pipeline_load_mma_do_producer_state,
pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state,
pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state
@ -1608,7 +1600,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {

tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
__syncwarp();

mma(
blk_coord,
problem_shape,
@ -1616,7 +1608,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
iter_count,
params.mainloop,
shared_storage.tensors,
pipeline_load_mma_q, pipeline_load_mma_q_consumer_state,
pipeline_load_mma_do, pipeline_load_mma_do_consumer_state,
pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state,
pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state,
@ -1629,7 +1621,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
else if (role == WarpRole::Compute) {
warpgroup_reg_set<RegisterAllocation::kCompute>();

compute(
blk_coord,
problem_shape,
@ -1660,7 +1652,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
else if (role == WarpRole::Reduce) {
warpgroup_reg_set<RegisterAllocation::kReduce>();

reduce(
blk_coord,
problem_shape,
@ -1677,9 +1669,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
}
else {
warpgroup_reg_set<RegisterAllocation::kEmpty>();

/* no-op */

}
}

@ -356,7 +356,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineO>();

CollectiveMainloop mainloop;
CollectiveEpilogue epilogue;
CollectiveEpilogue epilogue{params.epilogue};

if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) {
warpgroup_reg_set<NumRegsSoftmax>();
@ -407,7 +407,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
pipeline_s0_corr, pipeline_s0_corr_consumer_state,
pipeline_s1_corr, pipeline_s1_corr_consumer_state,
pipeline_mma_corr, pipeline_mma_corr_consumer_state,
pipeline_corr_epi, pipeline_corr_epi_producer_state
pipeline_corr_epi, pipeline_corr_epi_producer_state,
epilogue
);


@ -146,7 +146,7 @@ struct Sm100FmhaMlaReductionKernel {
ElementAcc sum_lse = 0;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
sum_lse = sum_lse + expf(local_lse[i] - params.scale * lse_max);
sum_lse = sum_lse + expf(local_lse[i] - lse_max);
}

CUTLASS_PRAGMA_UNROLL
@ -156,7 +156,7 @@ struct Sm100FmhaMlaReductionKernel {

sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);

ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + params.scale * lse_max;
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max;
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
gLSE(0) = global_lse;
}

@ -127,7 +127,7 @@ void __global__ fmha_reference_kernel(
mO(idx_Q + offset_Q, idx_D, idx_L) = static_cast<typename TensorO::value_type>(acc * scale);
}

if (threadIdx.x == 0) {
if (threadIdx.x == 0 && mLSE.data() != nullptr) {
mLSE(idx_Q + offset_Q, idx_L) = log(sum) + softmax_scale * maxS;
}

@ -75,6 +75,8 @@ struct DeviceAllocation {

size_t size() const { return size_; }

size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); }

void copy_from_host(const T* ptr, size_t sz) {
auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault);
assert(ret == cudaSuccess);
@ -280,7 +280,7 @@ auto make_iterator(T* ptr) {
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////

using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
// Command line options parsing
struct Options {

@ -133,7 +133,7 @@ using TP = _8;
static constexpr int TP_ = TP{};

#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))

// Distributed GEMM tiling/sharding schedule
// Choices:
@ -254,7 +254,8 @@ HostTensorB tensor_B_arr[TP_];
HostTensorD tensor_C_arr[TP_];
HostTensorD tensor_D_arr[TP_];

#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) &&
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
@ -347,7 +348,7 @@ struct Result {
};

#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
(__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))

/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
@ -805,17 +806,16 @@ int run(Options &options) {
return 0;
}

#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) &&
// (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))

///////////////////////////////////////////////////////////////////////////////////////////////////

int main(int argc, char const **args) {

// CUTLASS must be compiled with CUDA Toolkit 12.4 or newer to run this example
// CUTLASS must be compiled with CUDA Toolkit 12.8 or newer to run Blackwell kernels.
// and must have compute capability at least 90.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
// Some necessary cuda graph APIs were only introduced in CUDA 12.4.
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) {
std::cerr << "This example requires CUDA 12.4 or newer." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
@ -861,11 +861,11 @@ int main(int argc, char const **args) {
// Evaluate CUTLASS kernels
//

#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)))
run(options);
#else
std::cerr
<< "This example must be compiled with `sm100a` and CUDA Toolkit 12.4 or later." << std::endl;
<< "This example must be compiled with `sm100a` and CUDA Toolkit 12.8 or later." << std::endl;
return 0;
#endif

@ -14,8 +14,8 @@ cmake $PATH -DCUTLASS_NVCC_ARCHS="100a" -DCUTLASS_ENABLE_GDC_FOR_SM100=1
### Minimum software

Like all other CUTLASS examples, the NVIDIA driver, runtime, and CUDA Toolkit are required.
This example specifically requires CUDA Toolkit 12.6 or newer, due to some of the necessary
This example specifically requires CUDA Toolkit 12.8 or newer, since that is the first version
CUDA graph APIs.
supporting the Blackwell architecture.

### Hardware / driver settings

1192
examples/88_hopper_fmha/88_hopper_fmha.cu
Normal file
File diff suppressed because it is too large
50
examples/88_hopper_fmha/CMakeLists.txt
Normal file
@ -0,0 +1,50 @@
# Copyright (c) 2014 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

cutlass_example_add_executable(
88_hopper_fmha
88_hopper_fmha.cu
)

if(NOT WIN32 AND NOT CUTLASS_CLANG_HOST_COMPILE)

set_property(
SOURCE 88_hopper_fmha.cu
PROPERTY COMPILE_FLAGS "--use_fast_math"
)

cutlass_example_add_executable(
88_hopper_fmha_fp8
88_hopper_fmha.cu
)

target_compile_definitions(
88_hopper_fmha_fp8
PRIVATE FP8)

endif()
77
examples/88_hopper_fmha/README.md
Normal file
@ -0,0 +1,77 @@
# CUTLASS Hopper FMHA Example

This sample showcases how to implement fused multi-head attention (FMHA) using
CUTLASS for the NVIDIA Hopper architecture. At its heart, the forward pass of
FMHA is a GEMM-online softmax-GEMM fusion, whereas the backward pass is a slightly
more complex structure (basically, a GEMM-softmax-2xGEMM-2xGEMM fusion).
For more information please refer to the [Flash Attention 3 paper](https://arxiv.org/abs/2407.08608).

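For reference, the computation fused by the forward pass is the standard attention formula, with softmax scale $s$:

$$
S = s\,Q K^\top, \qquad
P = \mathrm{softmax}(S) = \frac{\exp\!\big(S - \mathrm{rowmax}(S)\big)}{\mathrm{rowsum}\!\big(\exp(S - \mathrm{rowmax}(S))\big)}, \qquad
O = P V .
$$

The online-softmax formulation evaluates this tile by tile over the KV sequence, keeping a running row maximum and row sum and rescaling the partial output whenever the maximum changes; the per-row log-sum-exp $\mathrm{LSE} = \mathrm{rowmax}(S) + \log \mathrm{rowsum}\big(\exp(S - \mathrm{rowmax}(S))\big)$ is the statistic the backward pass consumes.
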
The forward pass kernel supports head dims 32, 64, 128, and 256 for fp16 and bf16 input data types,
and head dims 128 and 256 for fp8.
All kernels use the Tensor Memory Accelerator for loads.
Kernels with head dims 128 and 256 have warp-specialized cooperative schedules.

Backward pass kernels (fp16 only) support head dims 32, 64, and 128, and all support
warp-specialized cooperative schedules.

## Customization

### Mask Fusion

Similar to the [Blackwell FMHA example](../77_blackwell_fmha/README.md), attention masks such as
causal masking can be fused into the kernel. To modify the code for such fusions,
`collective/fmha_fusion.hpp` provides the easiest customization point.
The `before_softmax` function is called with the accumulator of the first GEMM and the logical
positions of those elements. It is well-suited for applying masks or activations.

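As an illustrative sketch only (the real hook lives in `collective/fmha_fusion.hpp` and its exact signature may differ; the functor name below is hypothetical), a causal-style fusion acting on the first GEMM's accumulator could look roughly like this:

```cpp
#include "cutlass/cutlass.h"   // CUTLASS_DEVICE, CUTLASS_PRAGMA_UNROLL
#include "cute/tensor.hpp"     // cute::size, cute::get
#include <math.h>              // INFINITY

// Hypothetical fusion functor: masks out future (k > q) positions before the softmax.
struct CausalFusionSketch {
  template <class AccTensor, class PosTensor, class ProblemShape>
  CUTLASS_DEVICE void before_softmax(AccTensor& acc, PosTensor const& pos, ProblemShape const& /*problem_shape*/) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < cute::size(acc); ++i) {
      auto c = pos(i);                          // logical (q, k) coordinate of acc(i)
      if (cute::get<1>(c) > cute::get<0>(c)) {  // key position lies after the query position
        acc(i) = -INFINITY;                     // excluded from the softmax
      }
    }
  }
};
```
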
### MHA Variants

Using CuTe, it is easy to represent the various attention variants.
Where regular multi-head attention's layout for the head dimension is (numHeads:headStride),
for single-head attention it is simply (1:0) everywhere,
for GQA it is normal in Q and (numHeads/numGroups,numGroups:headStride,0) in KV,
and for MQA it is normal for Q and (numHeads:0) in KV.
As such, beyond general stride handling, no additional work is needed to support these,
and the example will just demonstrate regular multi-head attention.

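To make these head-mode layouts concrete, here is a small host-side sketch (the numbers are made up: 8 query heads, 2 GQA groups, head stride 64) showing how the KV head mode collapses for MQA and GQA simply by choosing strides:

```cpp
#include <cute/layout.hpp>
#include <cstdio>

int main() {
  using namespace cute;
  int num_heads = 8, num_groups = 2, head_stride = 64;

  // Regular MHA: every query head indexes its own KV head.
  auto mha_kv = make_layout(make_shape(num_heads), make_stride(head_stride));
  // MQA: stride 0 broadcasts a single KV head across all query heads.
  auto mqa_kv = make_layout(make_shape(num_heads), make_stride(0));
  // GQA: heads within a group share a KV head; groups are strided.
  auto gqa_kv = make_layout(make_shape(make_shape(num_heads / num_groups, num_groups)),
                            make_stride(make_stride(0, head_stride)));

  // Evaluating a layout at a query-head index yields the KV offset for that head.
  printf("head 5 -> MHA offset %d, MQA offset %d, GQA offset %d\n",
         int(mha_kv(5)), int(mqa_kv(5)), int(gqa_kv(5)));  // 320, 0, 64
  return 0;
}
```
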
### FP8

The warp-specialized forward kernel supports FP8 computation with both FP32 and FP16
accumulation for the Q*K product. Both variants can be enabled in the runner by defining FP8.

## Performance

Forward pass kernels can generally come close to the performance of FA3, but backward pass
kernels are more limited and are not expected to reach the same level of performance as FA3.

# Copyright

Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause

```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```
@ -0,0 +1,863 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||||
|
|
||||||
|
#include "../collective/fmha_common.hpp"
|
||||||
|
#include "../collective/fmha_collective_load.hpp"
|
||||||
|
#include "../collective/fmha_collective_softmax.hpp"
|
||||||
|
#include "../kernel/fmha_options.hpp"
|
||||||
|
|
||||||
|
namespace cutlass::fmha::collective {
|
||||||
|
|
||||||
|
template<
|
||||||
|
typename Element_,
|
||||||
|
typename ElementAccumulator_,
|
||||||
|
typename TileShape_, // BlockQO, BlockKV, BlockHead
|
||||||
|
class Fusion,
|
||||||
|
class... Options
|
||||||
|
>
|
||||||
|
struct FmhaBwdMainloopTmaWarpSpecialized {
|
||||||
|
|
||||||
|
using Element = Element_;
|
||||||
|
using ElementAccumulator = ElementAccumulator_;
|
||||||
|
using TileShape = TileShape_;
|
||||||
|
|
||||||
|
static constexpr bool kIsPersistent = false;
|
||||||
|
|
||||||
|
static const int NumLoadWarpGroups = 1;
|
||||||
|
static constexpr int NumMmaWarpGroups = 2;
|
||||||
|
static constexpr int StageCountQ = 2 /*K, V*/ * NumMmaWarpGroups;
|
||||||
|
static constexpr int StageCount = 2 /*Q, dO*/ * 2 /* actual stages */;
|
||||||
|
|
||||||
|
static const int kOuterLoads = 2;
|
||||||
|
using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
|
||||||
|
using Stages = cutlass::gemm::collective::StageCount<StageCount>;
|
||||||
|
using ClusterShape = Shape<_1, _1, _1>;
|
||||||
|
static_assert(StagesQ::value >= 2);
|
||||||
|
static_assert(Stages::value >= 2 * NumMmaWarpGroups);
|
||||||
|
|
||||||
|
// 16B alignment lets us use TMA
|
||||||
|
static constexpr int Alignment = 16 / sizeof(Element);
|
||||||
|
|
||||||
|
using TileShapeNM = Shape< // (N,M,D)
|
||||||
|
decltype(tuple_element_t<1, TileShape>{} / Int<NumMmaWarpGroups>{}),
|
||||||
|
tuple_element_t<0, TileShape>,
|
||||||
|
tuple_element_t<2, TileShape>>;
|
||||||
|
|
||||||
|
using TileShapeND = decltype(select<0,2,1>(TileShapeNM{})); // (N,D,M)
|
||||||
|
|
||||||
|
using TileShapeMD = decltype(select<2,1,0>(TileShapeND{})); // (M,D,N)
|
||||||
|
|
||||||
|
using CollectiveMmaNM = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||||
|
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
|
||||||
|
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
|
||||||
|
ElementAccumulator,
|
||||||
|
TileShapeNM, ClusterShape, Stages,
|
||||||
|
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
|
||||||
|
|
||||||
|
using CollectiveMmaND = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||||
|
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment, // from register, doesn't matter
|
||||||
|
Element, cute::tuple<_1, int, cute::tuple<int, int>>, Alignment,
|
||||||
|
ElementAccumulator,
|
||||||
|
TileShapeND, ClusterShape, Stages,
|
||||||
|
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
|
||||||
|
|
||||||
|
using CollectiveMmaND_SS = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||||
|
Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment, // from register, doesn't matter
|
||||||
|
Element, cute::tuple<_1, int, cute::tuple<int, int>>, Alignment,
|
||||||
|
ElementAccumulator,
|
||||||
|
TileShapeND, ClusterShape, Stages,
|
||||||
|
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
|
||||||
|
|
||||||
|
|
||||||
|
using CollectiveMmaMD = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||||
|
Element, cute::tuple<_1, int, cute::tuple<int, int>>, Alignment, // from smem, might matter (?)
|
||||||
|
Element, cute::tuple<_1, int, cute::tuple<int, int>>, Alignment,
|
||||||
|
ElementAccumulator,
|
||||||
|
TileShapeMD, ClusterShape, Stages,
|
||||||
|
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
|
||||||
|
|
||||||
|
using TiledMmaNM = typename CollectiveMmaNM::TiledMma;
|
||||||
|
using TiledMmaND_SS = typename CollectiveMmaND_SS::TiledMma;
|
||||||
|
using TiledMmaND_RS = decltype(convert_to_gmma_rs(typename CollectiveMmaND::TiledMma{}));
|
||||||
|
using TiledMmaND = TiledMmaND_RS;
|
||||||
|
using TiledMmaMD = typename CollectiveMmaMD::TiledMma;
|
||||||
|
|
||||||
|
using SmemLayoutQ = typename CollectiveMmaNM::SmemLayoutB;
|
||||||
|
using SmemLayoutK = typename CollectiveMmaNM::SmemLayoutA;
|
||||||
|
using SmemLayoutV = typename CollectiveMmaNM::SmemLayoutA;
|
||||||
|
using SmemLayoutDO = typename CollectiveMmaNM::SmemLayoutB;
|
||||||
|
|
||||||
|
//using SmemLayoutDQ = Layout<
|
||||||
|
// Shape<
|
||||||
|
// tuple_element_t<0, TileShapeMD>,
|
||||||
|
// Shape<_2, _4, decltype(tuple_element_t<1, TileShapeMD>{} / _8{})>,
|
||||||
|
// _2
|
||||||
|
// >,
|
||||||
|
// Stride<
|
||||||
|
// _4,
|
||||||
|
// Stride<decltype(tuple_element_t<0, TileShapeMD>{} * _4{}), _1, decltype(tuple_element_t<0, TileShapeMD>{} * _8{})>,
|
||||||
|
// decltype(tuple_element_t<0, TileShapeMD>{} * tuple_element_t<1, TileShapeMD>{})
|
||||||
|
// >>;
|
||||||
|
|
||||||
|
using SmemLayoutDQ_0 = Layout<
|
||||||
|
Shape<
|
||||||
|
tuple_element_t<0, TileShapeMD>,
|
||||||
|
tuple_element_t<1, TileShapeMD>,
|
||||||
|
_2
|
||||||
|
>,
|
||||||
|
Stride<
|
||||||
|
tuple_element_t<1, TileShapeMD>,
|
||||||
|
_1,
|
||||||
|
decltype(tuple_element_t<0, TileShapeMD>{} * tuple_element_t<1, TileShapeMD>{})
|
||||||
|
>>;
|
||||||
|
|
||||||
|
using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
|
||||||
|
cute::GMMA::Major::K, ElementAccumulator, tuple_element_t<0, TileShapeMD>, tuple_element_t<1, TileShapeMD>>());
|
||||||
|
using SmemLayoutDQ_1 = decltype(tile_to_shape(SmemAtomDQ{}, make_shape(get<0>(TileShapeMD{}), get<1>(TileShapeMD{}), _2{}), Step<_2, _1, _3>{}));
|
||||||
|
using SmemLayoutDQ = SmemLayoutDQ_1;
|
||||||
|
|
||||||
|
|
||||||
|
using PipelineDQ = cutlass::PipelineAsync<2>;
|
||||||
|
|
||||||
|
|
||||||
|
using SmemLayoutDS_0 = decltype(unstageSmemLayout(typename CollectiveMmaMD::SmemLayoutA{}, Int<NumMmaWarpGroups>{}));
|
||||||
|
|
||||||
|
using SmemLayoutDS = decltype(tile_to_shape(GMMA::Layout_MN_INTER_Atom<Element>{}, make_shape(size<0>(SmemLayoutDS_0{}), size<1>(SmemLayoutDS_0{}), size<2>(SmemLayoutDS_0{})), Step<_1, _2, _3>{}));
|
||||||
|
using SmemLayoutKp = typename CollectiveMmaMD::SmemLayoutB;
|
||||||
|
|
||||||
|
using SmemLayoutQp = typename CollectiveMmaND::SmemLayoutB;
|
||||||
|
using SmemLayoutDOp = typename CollectiveMmaND::SmemLayoutB;
|
||||||
|
|
||||||
|
using SmemLayoutLSE = Layout<Shape<tuple_element_t<1, TileShapeNM>, Int<StageCount>>>;
|
||||||
|
|
||||||
|
using MainloopPipeline = cutlass::PipelineTmaAsync<Stages::value>;
|
||||||
|
using MainloopPipelineQ = cutlass::PipelineTmaAsync<StagesQ::value>;
|
||||||
|
|
||||||
|
using PipelineState = typename cutlass::PipelineState<MainloopPipeline::Stages>;
|
||||||
|
using PipelineStateQ = typename cutlass::PipelineState<MainloopPipelineQ::Stages>;
|
||||||
|
|
||||||
|
using TileShapePV = TileShapeND; // To work with the kernel level
|
||||||
|
using TiledMmaPV = TiledMmaND;
|
||||||
|
|
||||||
|
static constexpr int kInnerLoadBytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element) + size(SmemLayoutLSE{}(_,_0{})) * sizeof(ElementAccumulator);
|
||||||
|
static constexpr int kOuterLoadBytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element);
|
||||||
|
|
||||||
|
  struct SharedStorage {
    // One for each consumer WG
    union {
      cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
      cute::array_aligned<Element, cute::cosize_v<SmemLayoutKp>> smem_kp;
      cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
    };

    cute::array_aligned<Element, cute::cosize_v<SmemLayoutDS>> smem_ds;

    // Loaded by producer, consumed by both WGs
    union {
      cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
      cute::array_aligned<Element, cute::cosize_v<SmemLayoutDO>> smem_do;
      cute::array_aligned<Element, cute::cosize_v<SmemLayoutQp>> smem_qp;
      cute::array_aligned<Element, cute::cosize_v<SmemLayoutDOp>> smem_dop;
    };

    // Accumulated into by both consumers, potentially loaded, potentially written
    cute::array_aligned<ElementAccumulator, cute::cosize_v<SmemLayoutDQ>> smem_dq;

    union {
      cute::array_aligned<ElementAccumulator, cute::cosize_v<SmemLayoutLSE>> smem_lse;
      cute::array_aligned<ElementAccumulator, cute::cosize_v<SmemLayoutLSE>> smem_sumOdO;
    };
  };

  struct Arguments {
    const Element* ptr_Q;
    cute::tuple<int, int, int, _1> dQ;
    const Element* ptr_K;
    cute::tuple<int, int, int, _1> dK;
    const Element* ptr_V;
    cute::tuple<int, int, int, _1> dV;

    const Element* ptr_dO;
    cute::tuple<int, int, int, _1> dDO;

    const ElementAccumulator* ptr_LSE;
    cute::tuple<int, int, _1> dLSE;
    const ElementAccumulator* ptr_sum_OdO;
    cute::tuple<int, int, _1> dSumOdO;

    ElementAccumulator* ptr_dQ;
    cute::tuple<int, int, int, _1> dDQ;
  };

  using TMA_Q = typename CollectiveMmaNM::Params::TMA_B;
  using TMA_K = typename CollectiveMmaNM::Params::TMA_A;
  using TMA_V = typename CollectiveMmaNM::Params::TMA_A;
  using TMA_DO = typename CollectiveMmaNM::Params::TMA_B;

  using TMA_LSE = decltype(make_tma_copy(SM90_TMA_LOAD{}, make_tensor((const ElementAccumulator*)nullptr, make_shape(1, 1, 1), make_stride(_1{}, 0, 0)), SmemLayoutLSE{}(_,_0{})));
  using TMA_ODO = TMA_LSE;

  using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor((const ElementAccumulator*)nullptr, make_shape(1, 1, 1, 1), make_stride(0, _1{}, 0, 0)), SmemLayoutDQ{}(_,_,_0{})));

  using LoadQ = CollectiveLoadTma<
    LoadKind::kBwdM,
    MainloopPipeline,
    Element,
    SmemLayoutQ,
    TMA_Q
  >;

  using LoadK = CollectiveLoadTma<
    LoadKind::kBwdN,
    MainloopPipelineQ,
    Element,
    SmemLayoutK,
    TMA_K
  >;

  using LoadV = CollectiveLoadTma<
    LoadKind::kBwdN,
    MainloopPipelineQ,
    Element,
    SmemLayoutV,
    TMA_V
  >;

  using LoadDO = CollectiveLoadTma<
    LoadKind::kBwdM,
    MainloopPipeline,
    Element,
    SmemLayoutDO,
    TMA_DO
  >;

  using LoadLSE = CollectiveLoadTma<
    LoadKind::kBwdScalar,
    MainloopPipeline,
    ElementAccumulator,
    SmemLayoutLSE,
    TMA_LSE
  >;

  using LoadODO = CollectiveLoadTma<
    LoadKind::kBwdScalar,
    MainloopPipeline,
    ElementAccumulator,
    SmemLayoutLSE,
    TMA_ODO
  >;

  struct Params {
    TMA_Q tma_load_q;
    TMA_K tma_load_k;
    TMA_V tma_load_v;
    TMA_DO tma_load_do;

    TMA_LSE tma_load_lse;
    TMA_ODO tma_load_odo;

    TMA_DQ tma_red_dq;

    float scale_softmax;
    float scale_softmax_log2;
  };

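  // scale_softmax is the usual 1/sqrt(head_dim) attention scale; scale_softmax_log2
  // additionally folds in log2(e), so exp(scale_softmax * x) can be evaluated with the
  // faster exp2f(scale_softmax_log2 * x) in the compute loop below.
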
  static_assert(size(TiledMmaNM{}) == size(TiledMmaND{}));
  static_assert(size(TiledMmaNM{}) == size(TiledMmaMD{}));

  template<class ProblemShape>
  static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
    return true
      && (get<4>(problem_size) <= get<2>(TileShape{}))
      && ((get<4>(problem_size) % Alignment) == 0)
      && ((get<2>(problem_size) % Alignment) == 0)
      ;
  }

  template<class ProblemShape>
  static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) {
    auto problem_shape_nm = make_shape(get<3>(problem_size), get<2>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size)));

    auto dK = make_stride(get<2>(args.dK), get<3>(args.dK), make_stride(get<0>(args.dK), get<1>(args.dK)));
    auto dQ = make_stride(get<2>(args.dQ), get<3>(args.dQ), make_stride(get<0>(args.dQ), get<1>(args.dQ)));
    auto params_nm_kq = CollectiveMmaNM::to_underlying_arguments(problem_shape_nm,
      typename CollectiveMmaNM::Arguments {
        args.ptr_K, dK,
        args.ptr_Q, dQ,
      }, /*workspace=*/ nullptr);

    auto dV = make_stride(get<2>(args.dV), get<3>(args.dV), make_stride(get<0>(args.dV), get<1>(args.dV)));
    auto dDO = make_stride(get<2>(args.dDO), get<3>(args.dDO), make_stride(get<0>(args.dDO), get<1>(args.dDO)));
    auto params_nm_vdo = CollectiveMmaNM::to_underlying_arguments(problem_shape_nm,
      typename CollectiveMmaNM::Arguments {
        args.ptr_V, dV,
        args.ptr_dO, dDO,
      }, /*workspace=*/ nullptr);

    TMA_LSE tma_load_lse = make_tma_copy(SM90_TMA_LOAD{}, make_tensor(args.ptr_LSE, select<2,0,1>(problem_size), select<2,0,1>(args.dLSE)), SmemLayoutLSE{}(_,_0{}));
    TMA_ODO tma_load_odo = make_tma_copy(SM90_TMA_LOAD{}, make_tensor(args.ptr_sum_OdO, select<2,0,1>(problem_size), select<2,0,1>(args.dSumOdO)), SmemLayoutLSE{}(_,_0{}));

    TMA_DQ tma_red_dq = make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor(args.ptr_dQ, select<2,4,0,1>(problem_size), select<2,3,0,1>(args.dDQ)), SmemLayoutDQ{}(_,_,_0{}));

    return Params{
      params_nm_kq.tma_load_b,
      params_nm_kq.tma_load_a,
      params_nm_vdo.tma_load_a,
      params_nm_vdo.tma_load_b,
      tma_load_lse, tma_load_odo,
      tma_red_dq,
      1.0f / (float) std::sqrt(get<4>(problem_size)),
      (float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size)))
    };
  }

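  // Note that dQ is written back via SM90_TMA_REDUCE_ADD rather than a plain TMA store:
  // every KV tile (N block) produces a partial dQ for the same rows of Q, so the per-tile
  // results are accumulated directly in global memory by the reducer (see reduce() below).
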
  template<class BlkCoord, class ProblemSize>
  CUTLASS_DEVICE
  auto
  get_inner_tile_count(BlkCoord const& blk_coord, ProblemSize const& problem_size) {
    return Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size);
  }

  CUTLASS_DEVICE
  static void prefetch_tma_descriptors(Params const& params) {
    cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
    cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
    cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
    cute::prefetch_tma_descriptor(params.tma_load_do.get_tma_descriptor());
    cute::prefetch_tma_descriptor(params.tma_load_odo.get_tma_descriptor());
    cute::prefetch_tma_descriptor(params.tma_load_lse.get_tma_descriptor());
  }

template<bool kLoadOuter, class BlkCoord, class ProblemShape, class LoadWarpBarrier>
|
||||||
|
CUTLASS_DEVICE void
|
||||||
|
load_kv_maybe_q(
|
||||||
|
int block_rank_in_cluster,
|
||||||
|
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
|
||||||
|
MainloopPipeline& pipeline_inner, PipelineState& smem_pipe_write_inner,
|
||||||
|
MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_write_outer,
|
||||||
|
SharedStorage& storage,
|
||||||
|
LoadWarpBarrier& load_warp_barrier, bool do_barrier)
|
||||||
|
{
|
||||||
|
// Load pattern:
|
||||||
|
// K0 V0 K1 V1
|
||||||
|
// Q0 DO0 Q1 DO1 Q2 DO2 ...
|
||||||
|
// K0 Q0 V0 K1 DO0 V1 ...
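// I.e., the K/V tiles for this CTA's N block are loaded once up front (interleaved with the
// first Q/dO tiles), while Q, dO, LSE and sum(O*dO) are then streamed once per inner M tile,
// so the outer and inner TMA loads overlap.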
|
||||||
|
int lane_predicate = cute::elect_one_sync();
|
||||||
|
|
||||||
|
int outer_tile_count = NumMmaWarpGroups;
|
||||||
|
int inner_tile_count = get_inner_tile_count(blk_coord, problem_size);
|
||||||
|
|
||||||
|
auto outer_tile_iter = cute::make_coord_iterator(outer_tile_count);
|
||||||
|
auto inner_tile_iter = cute::make_coord_iterator(inner_tile_count);
|
||||||
|
|
||||||
|
uint16_t mcast_mask_b = 0;
|
||||||
|
|
||||||
|
LoadQ load_q{params.tma_load_q, pipeline_inner, storage.smem_q};
|
||||||
|
auto load_state_q = load_q.init_state(block_rank_in_cluster, problem_size, TileShapeNM{}, blk_coord, inner_tile_count);
|
||||||
|
|
||||||
|
LoadDO load_do{params.tma_load_do, pipeline_inner, storage.smem_do};
|
||||||
|
auto load_state_do = load_do.init_state(block_rank_in_cluster, problem_size, TileShapeNM{}, blk_coord, inner_tile_count);
|
||||||
|
|
||||||
|
LoadK load_k{params.tma_load_k, pipeline_outer, storage.smem_k};
|
||||||
|
auto load_state_k = load_k.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);
|
||||||
|
|
||||||
|
LoadV load_v{params.tma_load_v, pipeline_outer, storage.smem_v};
|
||||||
|
auto load_state_v = load_v.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);
|
||||||
|
|
||||||
|
LoadLSE load_lse{params.tma_load_lse, pipeline_inner, storage.smem_lse};
|
||||||
|
auto load_state_lse = load_lse.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);
|
||||||
|
|
||||||
|
LoadODO load_odo{params.tma_load_odo, pipeline_inner, storage.smem_sumOdO};
|
||||||
|
auto load_state_odo = load_odo.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);
|
||||||
|
|
||||||
|
outer_tile_count *= 2; // K & V
|
||||||
|
inner_tile_count *= 4; // Q & dO & LSE & sumOdO
|
||||||
|
|
||||||
|
while (inner_tile_count > 0) {
|
||||||
|
if (Fusion{}.is_contributing(make_coord(*inner_tile_iter, get<1>(blk_coord)), TileShape{}, problem_size)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
inner_tile_count -= 4;
|
||||||
|
++inner_tile_iter;
|
||||||
|
}
|
||||||
|
|
||||||
|
if constexpr (kLoadOuter) {
|
||||||
|
load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);
|
||||||
|
}
|
||||||
|
|
||||||
|
load_q.template step<false,false,true>(inner_tile_iter, load_state_q, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
|
||||||
|
load_lse.template step<false,true,false>(inner_tile_iter, load_state_lse, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
|
||||||
|
|
||||||
|
if constexpr (! kLoadOuter) {
|
||||||
|
if (do_barrier) {
|
||||||
|
load_warp_barrier.arrive();
|
||||||
|
load_warp_barrier.wait(/*phase=*/ 0);
|
||||||
|
do_barrier = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if constexpr (kLoadOuter) {
|
||||||
|
load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);
|
||||||
|
load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);
|
||||||
|
}
|
||||||
|
|
||||||
|
load_do.template step<false,false,true>(inner_tile_iter, load_state_do, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
|
||||||
|
load_odo.template step<true,true,false>(inner_tile_iter, load_state_odo, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
|
||||||
|
|
||||||
|
if constexpr (kLoadOuter) {
|
||||||
|
load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);
|
||||||
|
}
|
||||||
|
|
||||||
|
if constexpr (kLoadOuter) {
|
||||||
|
while (outer_tile_count > 0) {
|
||||||
|
load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);
|
||||||
|
load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_NO_UNROLL
|
||||||
|
while (inner_tile_count > 0) {
|
||||||
|
while (inner_tile_count > 0) {
|
||||||
|
if (Fusion{}.is_contributing(make_coord(*inner_tile_iter, get<1>(blk_coord)), TileShape{}, problem_size)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
inner_tile_count -= 4;
|
||||||
|
++inner_tile_iter;
|
||||||
|
}
|
||||||
|
load_q.template step<false,false,true>(inner_tile_iter, load_state_q, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
|
||||||
|
load_lse.template step<false,true,false>(inner_tile_iter, load_state_lse, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
|
||||||
|
|
||||||
|
load_do.template step<false,false,true>(inner_tile_iter, load_state_do, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
|
||||||
|
load_odo.template step<true,true,false>(inner_tile_iter, load_state_odo, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
  template<class BlkCoord, class ProblemShape, class LoadWarpBarrier>
  CUTLASS_DEVICE void
  load_maybe_q(
      BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
      MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_write_outer,
      SharedStorage& storage,
      LoadWarpBarrier& load_warp_barrier, bool do_barrier)
  {
    // Load pattern:
    // K0 V0 K1 V1
    // Q0 DO0 Q1 DO1 Q2 DO2 ...
    // K0 Q0 V0 K1 DO0 V1 ...
    int lane_predicate = cute::elect_one_sync();

    int outer_tile_count = NumMmaWarpGroups;

    auto outer_tile_iter = cute::make_coord_iterator(outer_tile_count);

    LoadK load_k{params.tma_load_k, pipeline_outer, storage.smem_k};
    auto load_state_k = load_k.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);

    LoadV load_v{params.tma_load_v, pipeline_outer, storage.smem_v};
    auto load_state_v = load_v.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count);

    outer_tile_count *= 2; // K & V

    load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);

    if (do_barrier) {
      load_warp_barrier.arrive();
      load_warp_barrier.wait(/*phase=*/ 0);
      do_barrier = false;
    }

    load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);

    while (outer_tile_count > 0) {
      load_k.template step<false>(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count);
      load_v.template step<true>(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count);
    }
  }

  template<class BlkCoord, class ProblemShape, class MainloopPipelineReducer, class PipelineStateReducer>
  CUTLASS_DEVICE void
  reduce(
      BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
      MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_read_reducer,
      SharedStorage& storage)
  {
    int lane_predicate = cute::elect_one_sync();

    Tensor mDQ_full = params.tma_red_dq.get_tma_tensor(select<2,4,0,1>(problem_size));
    Tensor gDQ_full = local_tile(mDQ_full, TileShapeMD{}, make_coord(_, _, _), Step<_1, _1, Underscore>{});
    Tensor gDQ = gDQ_full(_, _, _, _0{}, get<2,0>(blk_coord), get<2,1>(blk_coord));
    Tensor sDQ = make_tensor(make_smem_ptr(storage.smem_dq.data()), SmemLayoutDQ{});

    auto block_tma = params.tma_red_dq.get_slice(_0{});

    Tensor tDQsDQ = block_tma.partition_S(sDQ);
    Tensor tDQgDQ = block_tma.partition_D(gDQ);

    int inner_tile_count = get_inner_tile_count(blk_coord, problem_size);
    int g_index = 0;

    auto smem_pipe_release_reducer = smem_pipe_read_reducer;
    bool first = true;
    while (inner_tile_count > 0) {
      while (inner_tile_count > 0) {
        if (Fusion{}.is_contributing(make_coord(g_index, get<1>(blk_coord)), TileShape{}, problem_size)) {
          break;
        }
        inner_tile_count -= 1;
        ++g_index;
      }
      if (inner_tile_count == 0) break;

      pipeline_reducer.consumer_wait(smem_pipe_read_reducer);
      if (lane_predicate == 1) {
        tma_store_wait<1>();
      }
      if (! first) {
        pipeline_reducer.consumer_release(smem_pipe_release_reducer);
        ++smem_pipe_release_reducer;
      } else {
        first = false;
      }
      if (lane_predicate == 1) {
        copy(params.tma_red_dq, tDQsDQ(_,_,_,smem_pipe_read_reducer.index()), tDQgDQ(_,_,_,g_index));
        tma_store_arrive();
      }
      ++smem_pipe_read_reducer;
      --inner_tile_count;
      ++g_index;
    }
    if (lane_predicate) {
      tma_store_wait<0>();
    }
    pipeline_reducer.consumer_release(smem_pipe_release_reducer);
    ++smem_pipe_release_reducer;
  }

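  // The dQ staging buffer has two slots (the _2 in SmemLayoutDQ), so tma_store_wait<1>()
  // only waits until at most one reduce-add is still in flight; the final tma_store_wait<0>()
  // drains the last outstanding TMA reduction before releasing the reducer pipeline.
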
template<class BlkCoord, class ProblemShape, class MainloopPipelineReducer, class PipelineStateReducer, class MathWgOrderBarrier>
|
||||||
|
CUTLASS_DEVICE auto
|
||||||
|
compute(
|
||||||
|
BlkCoord const& blk_coord, BlkCoord const& wg_coord,
|
||||||
|
Params const& params, ProblemShape const& problem_size,
|
||||||
|
MainloopPipeline& pipeline_inner, PipelineState& smem_pipe_read_inner,
|
||||||
|
MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_read_outer,
|
||||||
|
MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_write_reducer,
|
||||||
|
SharedStorage& storage,
|
||||||
|
MathWgOrderBarrier& math_wg_order_barrier)
|
||||||
|
{
|
||||||
|
TiledMmaND tiled_mma_nd;
|
||||||
|
|
||||||
|
Tensor acc_DV = partition_fragment_C(tiled_mma_nd, take<0,2>(TileShapeND{}));
|
||||||
|
clear(acc_DV);
|
||||||
|
|
||||||
|
Tensor acc_DK = partition_fragment_C(tiled_mma_nd, take<0,2>(TileShapeND{}));
|
||||||
|
clear(acc_DK);
|
||||||
|
|
||||||
|
int thread_idx = int(threadIdx.x) % cutlass::NumThreadsPerWarpGroup;
|
||||||
|
|
||||||
|
PipelineState smem_pipe_release_inner = smem_pipe_read_inner;
|
||||||
|
|
||||||
|
pipeline_outer.consumer_wait(smem_pipe_read_outer);
|
||||||
|
PipelineStateQ smem_pipe_read_k = smem_pipe_read_outer;
|
||||||
|
++smem_pipe_read_outer;
|
||||||
|
pipeline_outer.consumer_wait(smem_pipe_read_outer);
|
||||||
|
PipelineStateQ smem_pipe_read_v = smem_pipe_read_outer;
|
||||||
|
|
||||||
|
int inner_tile_count = get_inner_tile_count(wg_coord, problem_size);
|
||||||
|
|
||||||
|
TiledMmaNM tiled_mma_nm;
|
||||||
|
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
|
||||||
|
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
|
||||||
|
auto thr_mma_nm = tiled_mma_nm.get_thread_slice(thread_idx);
|
||||||
|
Tensor tSsK = thr_mma_nm.partition_A(sK);
|
||||||
|
Tensor tSsQ = thr_mma_nm.partition_B(sQ);
|
||||||
|
Tensor tSrK = thr_mma_nm.make_fragment_A(tSsK);
|
||||||
|
Tensor tSrQ = thr_mma_nm.make_fragment_B(tSsQ);
|
||||||
|
|
||||||
|
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
|
||||||
|
Tensor sDO = make_tensor(make_smem_ptr(storage.smem_do.data()), SmemLayoutDO{});
|
||||||
|
|
||||||
|
Tensor tDPsV = thr_mma_nm.partition_A(sV);
|
||||||
|
Tensor tDPsDO = thr_mma_nm.partition_B(sDO);
|
||||||
|
Tensor tDPrV = thr_mma_nm.make_fragment_A(tDPsV);
|
||||||
|
Tensor tDPrDO = thr_mma_nm.make_fragment_B(tDPsDO);
|
||||||
|
|
||||||
|
auto thr_mma_nd = tiled_mma_nd.get_thread_slice(thread_idx);
|
||||||
|
|
||||||
|
Tensor sDOp = make_tensor(make_smem_ptr(storage.smem_dop.data()), SmemLayoutDOp{});
|
||||||
|
Tensor tDV_sDO = thr_mma_nd.partition_B(sDOp);
|
||||||
|
Tensor tDVrDO = thr_mma_nd.make_fragment_B(tDV_sDO);
|
||||||
|
|
||||||
|
|
||||||
|
Tensor sQp = make_tensor(make_smem_ptr(storage.smem_qp.data()), SmemLayoutQp{});
|
||||||
|
Tensor tDK_sQ = thr_mma_nd.partition_B(sQp);
|
||||||
|
Tensor tDKrQ = thr_mma_nd.make_fragment_B(tDK_sQ);
|
||||||
|
|
||||||
|
|
||||||
|
int wg_idx = __shfl_sync(0xffffffff, get<1>(wg_coord) % NumMmaWarpGroups, 0);
|
||||||
|
|
||||||
|
TiledMmaMD tiled_mma_md;
|
||||||
|
auto thr_mma_md = tiled_mma_md.get_thread_slice(thread_idx);
|
||||||
|
Tensor sDS = make_tensor(make_smem_ptr(storage.smem_ds.data()), SmemLayoutDS{});
|
||||||
|
Tensor tDQsDS = thr_mma_md.partition_A(sDS);
|
||||||
|
Tensor tDQrDS_full = thr_mma_md.make_fragment_A(tDQsDS);
|
||||||
|
Tensor tDQrDS = tDQrDS_full(_,_,_,_);
|
||||||
|
Tensor sKp = make_tensor(make_smem_ptr(storage.smem_kp.data()), SmemLayoutKp{});
|
||||||
|
Tensor tDQsK = thr_mma_md.partition_B(sKp);
|
||||||
|
Tensor tDQrK = thr_mma_md.make_fragment_B(tDQsK);
|
||||||
|
|
||||||
|
Tensor sLSE = make_tensor(make_smem_ptr(storage.smem_lse.data()), make_shape(get<0>(TileShapeNM{}), get<1>(TileShapeNM{}), Int<StageCount>{}), make_stride(_0{}, _1{}, get<1>(TileShapeNM{})));
|
||||||
|
Tensor tSsLSE = thr_mma_nm.partition_C(sLSE);
|
||||||
|
|
||||||
|
Tensor sODO = make_tensor(make_smem_ptr(storage.smem_sumOdO.data()), make_shape(get<0>(TileShapeNM{}), get<1>(TileShapeNM{}), Int<StageCount>{}), make_stride(_0{}, _1{}, get<1>(TileShapeNM{})));
|
||||||
|
Tensor tDPsODO = thr_mma_nm.partition_C(sODO);
|
||||||
|
|
||||||
|
Tensor cS = make_identity_tensor(take<0,2>(TileShapeNM{}));
|
||||||
|
Tensor tScS = thr_mma_nm.partition_C(cS);
|
||||||
|
int n_block = get<1>(wg_coord);
|
||||||
|
tScS.data() = tScS.data() + E<0>{} * n_block * get<0>(TileShapeNM{});
|
||||||
|
|
||||||
|
|
||||||
|
// Transpose
|
||||||
|
Tensor sDSp_full = sDS.compose(make_layout(make_shape(size<1>(sDS), size<0>(sDS), size<2>(sDS)), make_stride(size<0>(sDS), _1{}, size<1>(sDS) * size<0>(sDS))));
|
||||||
|
Tensor sDSp = sDSp_full(_,_,_);
|
||||||
|
Tensor tDPsDS = thr_mma_nm.partition_C(sDSp);
|
||||||
|
|
||||||
|
auto thr_mma_nd_ss = TiledMmaND_SS{}.get_thread_slice(thread_idx);
|
||||||
|
Tensor tDKsDSp = thr_mma_nd_ss.partition_A(sDSp);
|
||||||
|
|
||||||
|
Tensor tDKrDSp = thr_mma_nd_ss.make_fragment_A(tDKsDSp);
|
||||||
|
|
||||||
|
Tensor sDQ = make_tensor(make_smem_ptr(storage.smem_dq.data()), SmemLayoutDQ{});
|
||||||
|
auto tDQsDQ_full = thr_mma_md.partition_C(sDQ);
|
||||||
|
|
||||||
|
|
||||||
|
auto smem_pipe_read_k_other = smem_pipe_read_k;
|
||||||
|
smem_pipe_read_k_other.advance(2);
|
||||||
|
|
||||||
|
int k_index = 0;
|
||||||
|
|
||||||
|
while (inner_tile_count > 0) {
|
||||||
|
while (inner_tile_count > 0) {
|
||||||
|
if (Fusion{}.is_contributing(make_coord(k_index, get<1>(blk_coord)), TileShape{}, problem_size)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
inner_tile_count -= 1;
|
||||||
|
tScS.data() = tScS.data() + E<1>{} * get<1>(TileShapeNM{});
|
||||||
|
k_index += 1;
|
||||||
|
}
|
||||||
|
if (inner_tile_count == 0) break;
|
||||||
|
|
||||||
|
pipeline_inner.consumer_wait(smem_pipe_read_inner);
|
||||||
|
PipelineState smem_pipe_read_q = smem_pipe_read_inner;
|
||||||
|
++smem_pipe_read_inner;
|
||||||
|
PipelineState smem_pipe_read_do = smem_pipe_read_inner;
|
||||||
|
++smem_pipe_read_inner;
|
||||||
|
|
||||||
|
// GEMM KQ -> S
|
||||||
|
Tensor acc_S = partition_fragment_C(tiled_mma_nm, take<0,2>(TileShapeNM{}));
|
||||||
|
|
||||||
|
warpgroup_fence_operand(acc_S);
|
||||||
|
warpgroup_arrive();
|
||||||
|
gemm_zero_acc(tiled_mma_nm, tSrK(_,_,_,smem_pipe_read_k.index()), tSrQ(_,_,_,smem_pipe_read_q.index()), acc_S);
|
||||||
|
warpgroup_commit_batch();
|
||||||
|
|
||||||
|
pipeline_inner.consumer_wait(smem_pipe_read_do);
|
||||||
|
|
||||||
|
// GEMM VdO -> dP
|
||||||
|
Tensor acc_DP = partition_fragment_C(tiled_mma_nm, take<0,2>(TileShapeNM{}));
|
||||||
|
|
||||||
|
warpgroup_fence_operand(acc_DP);
|
||||||
|
warpgroup_arrive();
|
||||||
|
gemm_zero_acc(tiled_mma_nm, tDPrV(_,_,_,smem_pipe_read_v.index()), tDPrDO(_,_,_,smem_pipe_read_do.index()), acc_DP);
|
||||||
|
warpgroup_commit_batch();
|
||||||
|
|
||||||
|
Tensor reg_LSE = make_fragment_like<ElementAccumulator>(acc_S);
|
||||||
|
for (int i = 0; i < size(reg_LSE); i++) {
|
||||||
|
reg_LSE(i) = ((ElementAccumulator)std::log2(std::exp(1.0))) * tSsLSE(_,_,_,smem_pipe_read_q.index())(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor reg_ODO = make_fragment_like<ElementAccumulator>(acc_S);
|
||||||
|
if constexpr (decltype(get<0>(TileShape{}) != _128{})::value) {
|
||||||
|
for (int i = 0; i < size(reg_ODO); i++) {
|
||||||
|
reg_ODO(i) = tDPsODO(_,_,_,smem_pipe_read_do.index())(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
warpgroup_wait<1>();
|
||||||
|
warpgroup_fence_operand(acc_S);
|
||||||
|
|
||||||
|
math_wg_order_barrier.wait();
|
||||||
|
// Compute S -> P
|
||||||
|
Fusion{}.before_softmax(acc_S, tScS, problem_size);
|
||||||
|
auto acc_P = make_fragment_like<ElementAccumulator>(acc_S);
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size(acc_P); i++) {
|
||||||
|
acc_P(i) = ::exp2f(params.scale_softmax_log2 * acc_S(i) - reg_LSE(i));
|
||||||
|
}
|
||||||
|
math_wg_order_barrier.arrive();
|
||||||
|
|
||||||
|
if constexpr (decltype(get<0>(TileShape{}) == _128{})::value) {
|
||||||
|
for (int i = 0; i < size(reg_ODO); i++) {
|
||||||
|
reg_ODO(i) = tDPsODO(_,_,_,smem_pipe_read_do.index())(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
warpgroup_wait<0>();
|
||||||
|
warpgroup_fence_operand(acc_DP);
|
||||||
|
|
||||||
|
// Compute dP P -> dS
|
||||||
|
auto acc_DS = make_fragment_like<Element>(acc_DP);
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size(acc_DS); i++) {
|
||||||
|
// We could move the scale out and into the respective epilogues (or a final scaling step)
|
||||||
|
acc_DS(i) = acc_P(i) * params.scale_softmax * (acc_DP(i) - reg_ODO(i));
|
||||||
|
}
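// This matches the standard FlashAttention-2 backward recurrence: with D = rowsum(dO * O)
// loaded as sum_OdO, dS = P * (dP - D); the 1/sqrt(d) softmax scale is folded in here
// instead of in the dK/dQ epilogues.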
|
||||||
|
|
||||||
|
// GEMM PdO -> dV
|
||||||
|
auto op_P = make_acc_into_op<Element>(acc_P, typename TiledMmaND::LayoutA_TV{});
|
||||||
|
warpgroup_fence_operand(acc_DV);
|
||||||
|
warpgroup_fence_operand(op_P);
|
||||||
|
warpgroup_arrive();
|
||||||
|
cute::gemm(tiled_mma_nd, op_P, tDVrDO(_,_,_,smem_pipe_read_do.index()), acc_DV);
|
||||||
|
warpgroup_commit_batch();
|
||||||
|
|
||||||
|
// Store dS to smem dS'
|
||||||
|
if (wg_idx == 0) math_wg_order_barrier.wait();
|
||||||
|
|
||||||
|
auto recast_bits = [](auto sz, auto t) {
|
||||||
|
return recast<uint_bit_t<decltype(sz)::value>>(t);
|
||||||
|
};
|
||||||
|
auto tDPsDS_v = recast_bits(Int<sizeof_bits_v<Element> * 2>{}, tDPsDS);
|
||||||
|
auto acc_DS_v = recast_bits(Int<sizeof_bits_v<Element> * 2>{}, acc_DS);
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size(acc_DS_v); i++) {
|
||||||
|
tDPsDS_v(_,_,_,wg_idx)(i) = acc_DS_v(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
cutlass::arch::fence_view_async_shared();
|
||||||
|
if (wg_idx == 0) math_wg_order_barrier.arrive();
|
||||||
|
|
||||||
|
// GEMM dS Q -> dK
|
||||||
|
if (wg_idx == 1) {
|
||||||
|
|
||||||
|
math_wg_order_barrier.wait();
|
||||||
|
|
||||||
|
// GEMM dS' K -> dQ
|
||||||
|
Tensor acc_DQ = partition_fragment_C(tiled_mma_md, take<0,2>(TileShapeMD{}));
|
||||||
|
|
||||||
|
warpgroup_fence_operand(acc_DQ);
|
||||||
|
warpgroup_arrive();
|
||||||
|
gemm_zero_acc(tiled_mma_md, tDQrDS(_,_,_,0), tDQrK(_,_,_,smem_pipe_read_k_other.index()), acc_DQ);
|
||||||
|
cute::gemm(tiled_mma_md, tDQrDS(_,_,_,1), tDQrK(_,_,_,smem_pipe_read_k.index()), acc_DQ);
|
||||||
|
warpgroup_commit_batch();
|
||||||
|
|
||||||
|
warpgroup_fence_operand(acc_DK);
|
||||||
|
warpgroup_arrive();
|
||||||
|
cute::gemm(TiledMmaND_SS{}, tDKrDSp(_,_,_,wg_idx), tDKrQ(_,_,_,smem_pipe_read_q.index()), acc_DK);
|
||||||
|
warpgroup_commit_batch();
|
||||||
|
|
||||||
|
warpgroup_wait<1>();
|
||||||
|
warpgroup_fence_operand(acc_DK);
|
||||||
|
|
||||||
|
warpgroup_wait<1>();
|
||||||
|
warpgroup_fence_operand(acc_DQ);
|
||||||
|
|
||||||
|
math_wg_order_barrier.arrive();
|
||||||
|
|
||||||
|
pipeline_reducer.producer_acquire(smem_pipe_write_reducer);
|
||||||
|
auto tDQsDQ = tDQsDQ_full(_,_,_,smem_pipe_write_reducer.index());
|
||||||
|
|
||||||
|
// Store dQ to smem dQ'
|
||||||
|
// Invoke TMA reduce on dQ'
|
||||||
|
using Vec = uint_bit_t<sizeof_bits_v<ElementAccumulator> * 2>;
|
||||||
|
auto tDQsDQ_v = recast<Vec>(tDQsDQ);
|
||||||
|
auto acc_DQ_v = recast<Vec>(acc_DQ);
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size(acc_DQ_v); i++) {
|
||||||
|
tDQsDQ_v(i) = acc_DQ_v(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
cutlass::arch::fence_view_async_shared();
|
||||||
|
|
||||||
|
pipeline_reducer.producer_commit(smem_pipe_write_reducer);
|
||||||
|
++smem_pipe_write_reducer;
|
||||||
|
} else {
|
||||||
|
|
||||||
|
warpgroup_fence_operand(acc_DK);
|
||||||
|
warpgroup_arrive();
|
||||||
|
cute::gemm(TiledMmaND_SS{}, tDKrDSp(_,_,_,wg_idx), tDKrQ(_,_,_,smem_pipe_read_q.index()), acc_DK);
|
||||||
|
warpgroup_commit_batch();
|
||||||
|
|
||||||
|
warpgroup_wait<1>();
|
||||||
|
warpgroup_fence_operand(acc_DK);
|
||||||
|
|
||||||
|
pipeline_reducer.producer_acquire(smem_pipe_write_reducer);
|
||||||
|
pipeline_reducer.producer_commit(smem_pipe_write_reducer);
|
||||||
|
++smem_pipe_write_reducer;
|
||||||
|
}
|
||||||
|
|
||||||
|
--inner_tile_count;
|
||||||
|
|
||||||
|
pipeline_inner.consumer_release(smem_pipe_release_inner);
|
||||||
|
++smem_pipe_release_inner;
|
||||||
|
pipeline_inner.consumer_release(smem_pipe_release_inner);
|
||||||
|
++smem_pipe_release_inner;
|
||||||
|
|
||||||
|
tScS.data() = tScS.data() + E<1>{} * get<1>(TileShapeNM{});
|
||||||
|
k_index += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
pipeline_outer.consumer_release(smem_pipe_read_k);
|
||||||
|
pipeline_outer.consumer_release(smem_pipe_read_outer);
|
||||||
|
pipeline_reducer.producer_tail(smem_pipe_write_reducer);
|
||||||
|
++smem_pipe_read_outer;
|
||||||
|
|
||||||
|
warpgroup_wait<0>();
|
||||||
|
warpgroup_fence_operand(acc_DK);
|
||||||
|
warpgroup_fence_operand(acc_DV);
|
||||||
|
|
||||||
|
return make_tuple(acc_DK, acc_DV);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace cutlass::fmha::collective
|
||||||
examples/88_hopper_fmha/collective/fmha_collective_load.hpp (new file, 140 lines)
@@ -0,0 +1,140 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"

namespace cutlass::fmha::collective {

enum class LoadKind {
  kQ, kK, kV,
  kBwdN, kBwdM, kBwdScalar
};

template<
  LoadKind kKind,
  class Pipeline,
  class Element,
  class SmemLayout,
  class TMA
>
struct CollectiveLoadTma {

  using Params = TMA;
  using SharedStorage = cute::array_aligned<Element, cute::cosize_v<SmemLayout>>;
  using PipelineState = typename cutlass::PipelineState<Pipeline::Stages>;

  Params const& params;
  Pipeline& pipeline;
  SharedStorage& storage;

  CUTLASS_DEVICE
  CollectiveLoadTma(Params const& params, Pipeline& pipeline, SharedStorage& storage)
    : params(params), pipeline(pipeline), storage(storage) {}

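  // init_g / init_state below build the (gmem, smem) TMA partitions for the operand selected
  // by kKind; step() then issues at most one TMA load per call, guarded by the producer
  // barrier of the pipeline stage it writes, so callers can interleave loads of different
  // operands.
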
template<class ProblemSize, class TileShape, class BlockCoord>
|
||||||
|
CUTLASS_DEVICE auto init_g(ProblemSize const& problem_size, TileShape const& tile_shape,
|
||||||
|
BlockCoord const& blk_coord, int loop_count
|
||||||
|
) {
|
||||||
|
using X = Underscore;
|
||||||
|
if constexpr (kKind == LoadKind::kK) {
|
||||||
|
Tensor mK_full = params.get_tma_tensor(make_shape(get<3>(problem_size), get<4>(problem_size), select<0,1>(problem_size)));
|
||||||
|
Tensor gK_full = local_tile(mK_full, tile_shape, make_coord(_, _, _), Step<X, _1, _1>{});
|
||||||
|
Tensor gK = gK_full(_, _, _, _0{}, get<2>(blk_coord));
|
||||||
|
return gK;
|
||||||
|
} else if constexpr (kKind == LoadKind::kQ) {
|
||||||
|
Tensor mQ_full = params.get_tma_tensor(make_shape(get<2>(problem_size), get<4>(problem_size), select<0,1>(problem_size)));
|
||||||
|
Tensor gQ_full = local_tile(mQ_full, tile_shape, make_coord(_, _, _), Step<_1, X, _1>{});
|
||||||
|
Tensor gQ = gQ_full(_, _, _, _0{}, get<2>(blk_coord));
|
||||||
|
return make_tensor(gQ.data() + loop_count * get<0>(blk_coord) * stride<2>(gQ), gQ.layout());
|
||||||
|
} else if constexpr (kKind == LoadKind::kV) {
|
||||||
|
Tensor mV_full = params.get_tma_tensor(make_shape(get<4>(problem_size), get<3>(problem_size), select<0,1>(problem_size)));
|
||||||
|
Tensor gV_full = local_tile(mV_full, tile_shape, make_coord(_, _, _), Step<X, _1, _1>{});
|
||||||
|
Tensor gV = gV_full(_, _, _0{}, _, get<2>(blk_coord));
|
||||||
|
return gV;
|
||||||
|
} else if constexpr (kKind == LoadKind::kBwdN) {
|
||||||
|
Tensor m_full = params.get_tma_tensor(make_shape(get<3>(problem_size), get<4>(problem_size), select<0,1>(problem_size)));
|
||||||
|
Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step<_1, X, _1>{});
|
||||||
|
Tensor g = g_full(_, _, _, _0{}, get<2>(blk_coord));
|
||||||
|
return make_tensor(g.data() + loop_count * get<1>(blk_coord) * stride<2>(g), g.layout());
|
||||||
|
} else if constexpr (kKind == LoadKind::kBwdM) {
|
||||||
|
Tensor m_full = params.get_tma_tensor(make_shape(get<2>(problem_size), get<4>(problem_size), select<0,1>(problem_size)));
|
||||||
|
Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step<X, _1, _1>{});
|
||||||
|
Tensor g = g_full(_, _, _, _0{}, get<2>(blk_coord));
|
||||||
|
return g;
|
||||||
|
} else if constexpr (kKind == LoadKind::kBwdScalar) {
|
||||||
|
Tensor m_full = params.get_tma_tensor(select<2,0,1>(problem_size));
|
||||||
|
Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step<X, _1, X>{});
|
||||||
|
Tensor g = g_full(_, _, get<2,0>(blk_coord), get<2,1>(blk_coord));
|
||||||
|
return g;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class ClusterRank, class ProblemSize, class TileShape, class BlockCoord>
|
||||||
|
CUTLASS_DEVICE auto init_state(ClusterRank const& block_rank_in_cluster,
|
||||||
|
ProblemSize const& problem_size, TileShape const& tile_shape,
|
||||||
|
BlockCoord const& block_coord, int loop_count
|
||||||
|
) {
|
||||||
|
Tensor g = init_g(problem_size, tile_shape, block_coord, loop_count);
|
||||||
|
Tensor s = make_tensor(make_smem_ptr(storage.data()), SmemLayout{});
|
||||||
|
|
||||||
|
auto block_tma = params.get_slice(block_rank_in_cluster);
|
||||||
|
Tensor ts = block_tma.partition_D(s);
|
||||||
|
Tensor tg = block_tma.partition_S(g);
|
||||||
|
|
||||||
|
return make_tuple(tg, ts);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<bool kAdvanceIterator=true, bool kAdvancePipe=true, bool kAcquireBarrier=true, class TileIterator, class State>
|
||||||
|
CUTLASS_DEVICE void step(TileIterator& tile_iter, State const& state,
|
||||||
|
PipelineState& smem_pipe_write,
|
||||||
|
int lane_predicate, int& tile_count, uint16_t mcast_mask = 0
|
||||||
|
) {
|
||||||
|
if ((lane_predicate == 1) && (tile_count > 0)) {
|
||||||
|
if constexpr (kAcquireBarrier) pipeline.producer_acquire(smem_pipe_write);
|
||||||
|
using BarrierType = typename Pipeline::ProducerBarrierType;
|
||||||
|
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
|
||||||
|
|
||||||
|
if constexpr (kKind == LoadKind::kBwdScalar) {
|
||||||
|
copy(params.with(*tma_barrier, mcast_mask), get<0>(state)(_,_,*tile_iter), get<1>(state)(_,_,smem_pipe_write.index()));
|
||||||
|
} else {
|
||||||
|
copy(params.with(*tma_barrier, mcast_mask), get<0>(state)(_,_,_,*tile_iter), get<1>(state)(_,_,_,smem_pipe_write.index()));
|
||||||
|
}
|
||||||
|
if constexpr (kAdvancePipe) ++smem_pipe_write;
|
||||||
|
if constexpr (kAdvanceIterator) ++tile_iter;
|
||||||
|
}
|
||||||
|
--tile_count;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace cutlass::fmha::collective
|
||||||
examples/88_hopper_fmha/collective/fmha_collective_softmax.hpp (new file, 305 lines)
@@ -0,0 +1,305 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"

#include "../collective/fmha_common.hpp"

namespace cutlass::fmha::collective {

template<
  class ElementAccumulator,
  class Fusion,
  class Params
>
struct CollectiveSoftmax {
  Params const& params;
  CUTLASS_DEVICE CollectiveSoftmax(Params const& params) : params(params) {}

  using SumType = float;
  using MaxType = ElementAccumulator;

  template<class AccPV, class TiledMmaPV>
  CUTLASS_DEVICE auto init(AccPV const& acc_pv, TiledMmaPV const& tiled_mma_pv) {
    Tensor s_max = make_fragment_like<MaxType>(size<0>(layout_acc_mn(tiled_mma_pv, acc_pv.layout())));
    Tensor a_sum = make_fragment_like<SumType>(s_max);
    return make_tuple(s_max, a_sum);
  }

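  // (s_max, a_sum) is the usual online-softmax state: the running per-row maximum and the
  // row sum of exponentials; the step() variants below update both and rescale the PV
  // accumulator by exp2((m_prev - m_new) * scale_softmax_log2) whenever the maximum grows.
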
CUTLASS_DEVICE float overload_exp2(float f) {
|
||||||
|
return ::exp2f(f);
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE cutlass::half_t overload_exp2(cutlass::half_t f) {
|
||||||
|
auto a = f.raw();
|
||||||
|
decltype(a) d;
|
||||||
|
asm("ex2.approx.f16 %0, %1;" : "=h"(d) : "h"(a));
|
||||||
|
return cutlass::half_t::bitcast(d);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
CUTLASS_DEVICE float overload_max(float a, float b) {
|
||||||
|
return ::max(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE cutlass::half_t overload_max(cutlass::half_t a, cutlass::half_t b) {
|
||||||
|
return cutlass::half_t{__hmax_nan(a.to_half(), b.to_half())};
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE half overload_to_native(cutlass::half_t f) {
|
||||||
|
return f.to_half();
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE float overload_to_native(float f) {
|
||||||
|
return f;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class AccQK, class TiledMmaQK, class CountQK, class State, class ProblemShape>
|
||||||
|
CUTLASS_DEVICE auto step(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, ProblemShape const& problem_shape) {
|
||||||
|
Fusion{}.before_softmax(acc_qk, count_qk, problem_shape);
|
||||||
|
Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout()));
|
||||||
|
auto reduction_target_qk = reduction_target_n(tiled_mma_qk);
|
||||||
|
constexpr int red_rank = decltype(rank(reduction_target_qk))::value;
|
||||||
|
|
||||||
|
auto& s_max = get<0>(state);
|
||||||
|
auto& a_sum = get<1>(state);
|
||||||
|
|
||||||
|
// Linear reduction is faster for the first iteration
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
|
||||||
|
s_max(i) = acc_qk_mn(i, 0);
|
||||||
|
}
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int j = 1; j < size<1>(acc_qk_mn); j++) {
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
|
||||||
|
s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for_each(make_seq<red_rank>{}, [&](auto r) {
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int j = 1; j < shape<r>(reduction_target_qk); j *= 2) {
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
|
||||||
|
s_max(i) = overload_max(s_max(i), MaxType{__shfl_xor_sync(uint32_t(-1), overload_to_native(s_max(i)), stride<r>(reduction_target_qk) * j)});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
|
||||||
|
MaxType local_max = s_max(i) == static_cast<MaxType>(-INFINITY) ? static_cast<MaxType>(0) : s_max(i);
|
||||||
|
MaxType scale = static_cast<MaxType>(params.scale_softmax_log2);
|
||||||
|
MaxType scale_max = scale * local_max;
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int j = 0; j < size<1>(acc_qk_mn); j++) {
|
||||||
|
acc_qk_mn(i, j) = overload_exp2(scale * acc_qk_mn(i, j) - scale_max);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
|
||||||
|
a_sum(i) = SumType{reduce(acc_qk_mn(i, _), cute::plus{})};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<bool kUseFusion=true, class AccQK, class TiledMmaQK, class CountQK, class State, class AccPV, class TiledMmaPV, class ProblemShape>
|
||||||
|
CUTLASS_DEVICE auto step_interleave_begin(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv, ProblemShape const& problem_shape) {
|
||||||
|
|
||||||
|
if constexpr (kUseFusion) {
|
||||||
|
Fusion{}.before_softmax(acc_qk, count_qk, problem_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout()));
|
||||||
|
Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout()));
|
||||||
|
|
||||||
|
static_assert(size<0>(acc_qk_mn) == size<0>(acc_pv_mn));
|
||||||
|
auto reduction_target_qk = reduction_target_n(tiled_mma_qk);
|
||||||
|
constexpr int red_rank = decltype(rank(reduction_target_qk))::value;
|
||||||
|
|
||||||
|
auto& s_max = get<0>(state);
|
||||||
|
auto& a_sum = get<1>(state);
|
||||||
|
|
||||||
|
Tensor s_max_prev = make_fragment_like(s_max);
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
|
||||||
|
s_max_prev(i) = s_max(i);
|
||||||
|
}
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
|
||||||
|
// Linear reduction is faster here, as well
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int j = 0; j < size<1>(acc_qk_mn); j++) {
|
||||||
|
s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// reduce max
|
||||||
|
for_each(make_seq<red_rank>{}, [&](auto r) {
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int j = 1; j < shape<r>(reduction_target_qk); j *= 2) {
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
|
||||||
|
s_max(i) = overload_max(s_max(i), __shfl_xor_sync(uint32_t(-1), s_max(i), stride<r>(reduction_target_qk) * j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size<0>(acc_pv_mn); i++) {
|
||||||
|
float s_max_cur = s_max(i) == -INFINITY ? 0.0f : s_max(i);
|
||||||
|
float scale = ::exp2f((s_max_prev(i) - s_max_cur) * params.scale_softmax_log2);
|
||||||
|
a_sum(i) *= scale;
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int j = 0; j < size<1>(acc_pv_mn); j++) {
|
||||||
|
acc_pv_mn(i, j) *= scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class AccQK_MN, class State>
|
||||||
|
CUTLASS_DEVICE auto step_interleave_step(AccQK_MN& acc_qk_mn, State& state) {
|
||||||
|
|
||||||
|
auto& s_max = get<0>(state);
|
||||||
|
auto& a_sum = get<1>(state);
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int j = 0; j < size<0>(acc_qk_mn); j++) {
|
||||||
|
float local_max = s_max(j) == -INFINITY ? 0.f : s_max(j);
|
||||||
|
float scale_max = params.scale_softmax_log2 * local_max;
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int k = 0; k < size<1>(acc_qk_mn); k++) {
|
||||||
|
acc_qk_mn(j, k) = ::exp2f(params.scale_softmax_log2 * acc_qk_mn(j, k) - scale_max);
|
||||||
|
a_sum(j) += acc_qk_mn(j, k);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<bool kUseFusion=true, class AccQK, class TiledMmaQK, class CountQK, class State, class AccPV, class TiledMmaPV, class ProblemShape>
|
||||||
|
CUTLASS_DEVICE auto step(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv, ProblemShape const& problem_shape) {
|
||||||
|
|
||||||
|
if constexpr (kUseFusion) {
|
||||||
|
Fusion{}.before_softmax(acc_qk, count_qk, problem_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout()));
|
||||||
|
Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout()));
|
||||||
|
|
||||||
|
static_assert(size<0>(acc_qk_mn) == size<0>(acc_pv_mn));
|
||||||
|
auto reduction_target_qk = reduction_target_n(tiled_mma_qk);
|
||||||
|
constexpr int red_rank = decltype(rank(reduction_target_qk))::value;
|
||||||
|
|
||||||
|
auto& s_max = get<0>(state);
|
||||||
|
auto& a_sum = get<1>(state);
|
||||||
|
|
||||||
|
Tensor s_max_prev = make_fragment_like(s_max);
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size<0>(acc_qk_mn); i++) {
|
||||||
|
s_max_prev(i) = s_max(i);
|
||||||
|
|
||||||
|
// Linear reduction is faster here, as well
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int j = 0; j < size<1>(acc_qk_mn); j++) {
|
||||||
|
s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j));
|
||||||
|
}
|
||||||
|
// reduce max
|
||||||
|
for_each(make_seq<red_rank>{}, [&](auto r) {
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int j = 1; j < shape<r>(reduction_target_qk); j *= 2) {
|
||||||
|
s_max(i) = overload_max(s_max(i), MaxType{__shfl_xor_sync(uint32_t(-1), overload_to_native(s_max(i)), stride<r>(reduction_target_qk) * j)});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
MaxType local_max = s_max(i) == static_cast<MaxType>(-INFINITY) ? static_cast<MaxType>(0) : s_max(i);
|
||||||
|
MaxType scale = static_cast<MaxType>(params.scale_softmax_log2);
|
||||||
|
MaxType scale_max = scale * local_max;
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int j = 0; j < size<1>(acc_qk_mn); j++) {
|
||||||
|
acc_qk_mn(i, j) = overload_exp2(scale * acc_qk_mn(i, j) - scale_max);
|
||||||
|
}
|
||||||
|
|
||||||
|
MaxType s_max_cur = s_max(i) == static_cast<MaxType>(-INFINITY) ? static_cast<MaxType>(0) : s_max(i);
|
||||||
|
SumType scale_pv = overload_exp2((s_max_prev(i) - s_max_cur) * scale);
|
||||||
|
a_sum(i) *= scale_pv;
|
||||||
|
|
||||||
|
using ElementPV = typename AccPV::value_type;
|
||||||
|
ElementPV scale_pv_ele = static_cast<ElementPV>(scale_pv);
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int j = 0; j < size<1>(acc_pv_mn); j++) {
|
||||||
|
acc_pv_mn(i, j) *= scale_pv_ele;
|
||||||
|
}
|
||||||
|
a_sum(i) += SumType{reduce(acc_qk_mn(i, _), cute::plus{})};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template<class State, class AccPV, class TiledMmaPV>
|
||||||
|
CUTLASS_DEVICE auto tail(State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv) {
|
||||||
|
auto& s_max = get<0>(state);
|
||||||
|
auto& a_sum = get<1>(state);
|
||||||
|
|
||||||
|
Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout()));
|
||||||
|
|
||||||
|
auto reduction_target = reduction_target_n(tiled_mma_pv);
|
||||||
|
constexpr int red_rank = decltype(rank(reduction_target))::value;
|
||||||
|
for_each(make_seq<red_rank>{}, [&](auto r) {
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int j = 1; j < shape<r>(reduction_target); j *= 2) {
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size<0>(acc_pv_mn); i++) {
|
||||||
|
a_sum(i) = a_sum(i) + __shfl_xor_sync(uint32_t(-1), a_sum(i), stride<r>(reduction_target) * j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Tensor acc_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout()));
|
||||||
|
|
||||||
|
Tensor lse = make_fragment_like(a_sum);
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size<0>(acc_mn); i++) {
|
||||||
|
float sum = a_sum(i);
|
||||||
|
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : __frcp_rn(sum);
|
||||||
|
lse(i) = (sum == 0.f || sum != sum) ? INFINITY : s_max(i) * params.scale_softmax + __logf(sum);
|
||||||
|
float scale = params.rp_dropout * inv_sum;
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int j = 0; j < size<1>(acc_mn); j++) {
|
||||||
|
acc_mn(i, j) *= scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return lse;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace cutlass::fmha::collective
|
||||||
examples/88_hopper_fmha/collective/fmha_collective_tma.hpp (new file, 526 lines)
@@ -0,0 +1,526 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "../collective/fmha_common.hpp"
#include "../collective/fmha_collective_load.hpp"
#include "../collective/fmha_collective_softmax.hpp"
#include "../kernel/fmha_options.hpp"

namespace cutlass::fmha::collective {

using namespace cute;
using cutlass::fmha::kernel::Tag;
using cutlass::fmha::kernel::find_option_t;

template<
  typename Element_,
  typename ElementAccumulator_,
  typename TileShape_, // BlockQO, BlockKV, BlockHead
  class Fusion,
  class... Options
>
struct FmhaMainloopTma {

  using Element = Element_;
  using ElementAccumulator = ElementAccumulator_;
  using TileShape = TileShape_;

  // Options
  using kClusterM = find_option_t<Tag::kClusterM, Int<1>, Options...>;
  static constexpr int StageCount = find_option_t<Tag::kStagesKV, Int<4>, Options...>::value;
  static constexpr int StageCountQ = find_option_t<Tag::kStagesQ, Int<1>, Options...>::value;

  using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
  using Stages = cutlass::gemm::collective::StageCount<StageCount>;
  using ClusterShape = Shape<kClusterM, _1, _1>;

  // 16B alignment lets us use TMA
  static constexpr int Alignment = 16 / sizeof(Element);

  using TileShapeQK = TileShape;
  using TileShapePV = decltype(select<0,2,1>(TileShapeQK{}));

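  // E.g. for a 2-byte Element this gives Alignment = 8 elements, i.e. 16-byte aligned rows
  // as required by TMA; TileShapePV is TileShapeQK with the KV and head-dim modes swapped so
  // the second GEMM (P*V) reuses the same tile sizes.
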
using LayoutQKV = cute::tuple<int, _1, cute::tuple<int, int>>;
|
||||||
|
using LayoutQ = LayoutQKV;
|
||||||
|
using LayoutK = LayoutQKV;
|
||||||
|
using LayoutV = LayoutQKV;
|
||||||
|
|
||||||
|
using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||||
|
Element, LayoutQ, Alignment,
|
||||||
|
Element, LayoutK, Alignment,
|
||||||
|
ElementAccumulator,
|
||||||
|
TileShapeQK, ClusterShape, Stages,
|
||||||
|
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
|
||||||
|
|
||||||
|
using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||||
|
// the stride for A does not matter since we do not load from smem at all
|
||||||
|
Element, LayoutK, Alignment,
|
||||||
|
Element, decltype(select<1,0,2>(LayoutV{})), Alignment,
|
||||||
|
ElementAccumulator,
|
||||||
|
TileShapePV, ClusterShape, Stages,
|
||||||
|
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
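// Note: these two GEMM collectives are not run as full CUTLASS mainloops here; the
// FMHA mainloop below only borrows their TiledMma types, shared-memory layouts and
// TMA descriptors (Params::TMA_A / TMA_B) and drives the pipelines itself.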
|
||||||
|
|
||||||
|
using TiledMmaQK = typename CollectiveMmaQK::TiledMma;
|
||||||
|
using TiledMmaPV = decltype(convert_to_gmma_rs(typename CollectiveMmaPV::TiledMma{}));
|
||||||
|
|
||||||
|
using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int<StagesQ::value>{}));
|
||||||
|
using SmemLayoutK = typename CollectiveMmaQK::SmemLayoutB;
|
||||||
|
using SmemLayoutV = typename CollectiveMmaPV::SmemLayoutB;
|
||||||
|
|
||||||
|
using MainloopPipeline = cutlass::PipelineTmaAsync<Stages::value>;
|
||||||
|
using MainloopPipelineQ = cutlass::PipelineTmaAsync<StagesQ::value>;
|
||||||
|
|
||||||
|
using PipelineState = typename cutlass::PipelineState<MainloopPipeline::Stages>;
|
||||||
|
using PipelineStateQ = typename cutlass::PipelineState<MainloopPipelineQ::Stages>;
|
||||||
|
|
||||||
|
using TileShapeOut = TileShapePV;
|
||||||
|
using TiledMmaOut = TiledMmaPV;
|
||||||
|
using ElementOut = ElementAccumulator;
|
||||||
|
|
||||||
|
struct SharedStorage {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
union {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
};
};
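// Since smem_k and smem_v alias each other through the union, K and V tiles share one
// set of pipeline stages; the prologue below fills the stages alternately (even stages
// with K, odd stages with V), which is why the TMA tile count is 2 * fusion_tile_count.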
|
||||||
|
|
||||||
|
struct Arguments {
|
||||||
|
const Element* ptr_Q;
|
||||||
|
LayoutQ dQ;
|
||||||
|
const Element* ptr_K;
|
||||||
|
LayoutK dK;
|
||||||
|
const Element* ptr_V;
|
||||||
|
LayoutV dV;
|
||||||
|
};
|
||||||
|
|
||||||
|
using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
|
||||||
|
using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
|
||||||
|
using TMA_V = typename CollectiveMmaPV::Params::TMA_B;
|
||||||
|
|
||||||
|
struct Params {
|
||||||
|
TMA_Q tma_load_q;
|
||||||
|
TMA_K tma_load_k;
|
||||||
|
TMA_V tma_load_v;
|
||||||
|
|
||||||
|
float scale_softmax;
|
||||||
|
float scale_softmax_log2;
|
||||||
|
float rp_dropout;
|
||||||
|
};
|
||||||
|
|
||||||
|
using LoadQ = cutlass::fmha::collective::CollectiveLoadTma<
|
||||||
|
cutlass::fmha::collective::LoadKind::kQ,
|
||||||
|
MainloopPipelineQ,
|
||||||
|
Element,
|
||||||
|
SmemLayoutQ,
|
||||||
|
TMA_Q
|
||||||
|
>;
|
||||||
|
|
||||||
|
using LoadK = cutlass::fmha::collective::CollectiveLoadTma<
|
||||||
|
cutlass::fmha::collective::LoadKind::kK,
|
||||||
|
MainloopPipeline,
|
||||||
|
Element,
|
||||||
|
SmemLayoutK,
|
||||||
|
TMA_K
|
||||||
|
>;
|
||||||
|
|
||||||
|
using LoadV = cutlass::fmha::collective::CollectiveLoadTma<
|
||||||
|
cutlass::fmha::collective::LoadKind::kV,
|
||||||
|
MainloopPipeline,
|
||||||
|
Element,
|
||||||
|
SmemLayoutV,
|
||||||
|
TMA_V
|
||||||
|
>;
|
||||||
|
|
||||||
|
static_assert(size(typename CollectiveMmaQK::TiledMma{}) == size(typename CollectiveMmaPV::TiledMma{}));
|
||||||
|
|
||||||
|
static const int MaxThreadsPerBlock = size(typename CollectiveMmaQK::TiledMma{});
|
||||||
|
|
||||||
|
template<class ProblemShape>
|
||||||
|
static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
|
||||||
|
return true
|
||||||
|
&& (get<4>(problem_size) <= get<2>(TileShape{}))
|
||||||
|
&& ((get<4>(problem_size) % Alignment) == 0)
|
||||||
|
&& ((get<2>(problem_size) % Alignment) == 0)
|
||||||
|
;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class ProblemShape>
|
||||||
|
static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) {
|
||||||
|
|
||||||
|
auto problem_shape_qk = make_shape(get<2>(problem_size), get<3>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size)));
|
||||||
|
auto params_qk = CollectiveMmaQK::to_underlying_arguments(problem_shape_qk,
|
||||||
|
typename CollectiveMmaQK::Arguments {
|
||||||
|
args.ptr_Q, args.dQ,
|
||||||
|
args.ptr_K, args.dK,
|
||||||
|
}, /*workspace=*/ nullptr);
|
||||||
|
|
||||||
|
auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk);
|
||||||
|
auto params_pv = CollectiveMmaPV::to_underlying_arguments(problem_shape_pv,
|
||||||
|
typename CollectiveMmaPV::Arguments {
|
||||||
|
args.ptr_K, args.dK, // never used, dummy
|
||||||
|
args.ptr_V, select<1,0,2>(args.dV),
|
||||||
|
}, /*workspace=*/ nullptr);
|
||||||
|
|
||||||
|
return Params{
params_qk.tma_load_a,
params_qk.tma_load_b,
params_pv.tma_load_b,
1.0f / (float) std::sqrt(get<4>(problem_size)),
(float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size))),
1.0f
};
}
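// scale_softmax is the usual 1/sqrt(head_dim) attention scaling; scale_softmax_log2
// additionally folds in log2(e), i.e. log2(e)/sqrt(head_dim), which lets the softmax
// collective evaluate exp(x/sqrt(d)) as exp2(x * scale_softmax_log2) if it chooses to
// use the cheaper base-2 exponential.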
|
||||||
|
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
static void prefetch_tma_descriptors(Params const& params) {
|
||||||
|
cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
|
||||||
|
cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
|
||||||
|
cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class BlkCoord, class ProblemShape>
|
||||||
|
CUTLASS_DEVICE auto
|
||||||
|
compute(
|
||||||
|
int block_rank_in_cluster,
|
||||||
|
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
|
||||||
|
MainloopPipeline& pipeline, PipelineState& smem_pipe_read, PipelineState& smem_pipe_write,
|
||||||
|
MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_read_q, PipelineStateQ& smem_pipe_write_q,
|
||||||
|
SharedStorage& storage)
|
||||||
|
{
|
||||||
|
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||||
|
int thread_idx = threadIdx.x;
|
||||||
|
|
||||||
|
PipelineState smem_pipe_release = smem_pipe_read;
|
||||||
|
[[maybe_unused]] PipelineStateQ smem_pipe_release_q = smem_pipe_read_q;
|
||||||
|
|
||||||
|
|
||||||
|
int fusion_tile_count = Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size);
|
||||||
|
|
||||||
|
LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q};
|
||||||
|
auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, 1);
|
||||||
|
|
||||||
|
LoadK load_k{params.tma_load_k, pipeline, storage.smem_k};
|
||||||
|
auto load_state_k = load_k.init_state(block_rank_in_cluster, problem_size, TileShapeQK{}, blk_coord, fusion_tile_count);
|
||||||
|
|
||||||
|
LoadV load_v{params.tma_load_v, pipeline, storage.smem_v};
|
||||||
|
auto load_state_v = load_v.init_state(block_rank_in_cluster, problem_size, TileShapePV{}, blk_coord, fusion_tile_count);
|
||||||
|
|
||||||
|
// Set predicate for the lowest lane_id in the warp
|
||||||
|
int lane_predicate = cute::elect_one_sync();
|
||||||
|
|
||||||
|
// Issue TmaLoads (Prologue fetches)
if (warp_idx == 0) {
auto q_tile_iter = cute::make_coord_iterator(1);
int q_tile_count = 1;
load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count);
}

// Loop over K elems
auto k_tile_iter = cute::make_coord_iterator(fusion_tile_count);

int k_tile_count_tma = 2 * fusion_tile_count;

uint16_t mcast_mask_b = 0;

if (warp_idx == 0 && lane_predicate == 1) {
if constexpr (cute::is_same_v<typename CollectiveMmaQK::GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_b |= (uint16_t(1) << block_layout(m,_0{},Int<0>{}));
}
}

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < StageCount; i++) {
if (i % 2 == 0) {
load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
} else {
load_v.template step<true>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
}
}
}
|
||||||
|
|
||||||
|
TiledMmaQK tiled_mma_qk;
|
||||||
|
auto thr_mma_qk = tiled_mma_qk.get_thread_slice(thread_idx);
|
||||||
|
|
||||||
|
// Mainloop setup QK
|
||||||
|
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
|
||||||
|
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
|
||||||
|
|
||||||
|
Tensor tSsQ = thr_mma_qk.partition_A(sQ); // (MMA,MMA_M,MMA_K,PIPE)
|
||||||
|
Tensor tSsK = thr_mma_qk.partition_B(sK); // (MMA,MMA_N,MMA_K,PIPE)
|
||||||
|
Tensor tSrQ = thr_mma_qk.make_fragment_A(tSsQ); // (MMA,MMA_M,MMA_K,PIPE)
|
||||||
|
Tensor tSrK = thr_mma_qk.make_fragment_B(tSsK); // (MMA,MMA_N,MMA_K,PIPE)
|
||||||
|
|
||||||
|
// Prepare: MMA PV
|
||||||
|
TiledMmaPV tiled_mma_pv;
|
||||||
|
auto thr_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx);
|
||||||
|
|
||||||
|
// Mainloop setup PV
|
||||||
|
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
|
||||||
|
|
||||||
|
Tensor tOsV = thr_mma_pv.partition_B(sV); // (MMA,MMA_N,MMA_K,PIPE)
|
||||||
|
Tensor tOrV = thr_mma_pv.make_fragment_B(tOsV); // (MMA,MMA_N,MMA_K,PIPE)
|
||||||
|
|
||||||
|
int k_tile_count = Fusion{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_size);
|
||||||
|
|
||||||
|
pipeline_q.consumer_wait(smem_pipe_read_q);
|
||||||
|
|
||||||
|
// mapping into QK accumulator
|
||||||
|
Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{}));
|
||||||
|
Tensor tPcP = thr_mma_qk.partition_C(cP);
|
||||||
|
int m_block = get<0>(blk_coord);
|
||||||
|
tPcP.data() = tPcP.data() + E<0>{} * m_block * get<0>(TileShapeQK{});
|
||||||
|
|
||||||
|
// Allocate PV acc
|
||||||
|
Tensor acc_pv = partition_fragment_C(tiled_mma_pv, take<0, 2>(TileShapePV{}));
|
||||||
|
|
||||||
|
cutlass::fmha::collective::CollectiveSoftmax<ElementAccumulator, Fusion, decltype(params)> softmax{params};
|
||||||
|
auto softmax_state = softmax.init(acc_pv, tiled_mma_pv);
|
||||||
|
|
||||||
|
if (true)
|
||||||
|
{
|
||||||
|
--k_tile_count;
|
||||||
|
// Allocate QK acc
|
||||||
|
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
|
||||||
|
|
||||||
|
pipeline.consumer_wait(smem_pipe_read);
|
||||||
|
|
||||||
|
// MMA QK
|
||||||
|
warpgroup_fence_operand(acc_qk);
|
||||||
|
warpgroup_arrive();
|
||||||
|
|
||||||
|
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
|
||||||
|
warpgroup_commit_batch();
|
||||||
|
|
||||||
|
++smem_pipe_read;
|
||||||
|
|
||||||
|
// Wait for the pipeline MMAs to drain
|
||||||
|
warpgroup_wait<0>();
|
||||||
|
warpgroup_fence_operand(acc_qk);
|
||||||
|
|
||||||
|
softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, problem_size);
|
||||||
|
|
||||||
|
Tensor acc_qk_fixed = make_fragment_like<Element>(convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{})));
|
||||||
|
|
||||||
|
Tensor acc_qk_input = make_tensor(acc_qk_fixed.data(), acc_qk.layout());
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size(acc_qk); i++) {
|
||||||
|
acc_qk_input(i) = static_cast<Element>(acc_qk(i));
|
||||||
|
}
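// At this point the QK accumulator has been converted in registers from
// ElementAccumulator to Element so it can serve as the A operand of the PV GEMM;
// TiledMmaPV was switched to an RS GMMA above so this operand is sourced from
// registers instead of shared memory.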
|
||||||
|
|
||||||
|
pipeline.consumer_wait(smem_pipe_read);
|
||||||
|
|
||||||
|
// MMA PV
|
||||||
|
warpgroup_fence_operand(acc_pv);
|
||||||
|
warpgroup_fence_operand(acc_qk_fixed);
|
||||||
|
warpgroup_arrive();
|
||||||
|
|
||||||
|
gemm_zero_acc(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv);
|
||||||
|
warpgroup_commit_batch();
|
||||||
|
|
||||||
|
//
|
||||||
|
// Advance the pipe
|
||||||
|
//
|
||||||
|
|
||||||
|
// Advance consumer pipeline
|
||||||
|
++smem_pipe_read;
|
||||||
|
|
||||||
|
pipeline.consumer_release(smem_pipe_release);
|
||||||
|
++smem_pipe_release;
|
||||||
|
|
||||||
|
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_NO_UNROLL
|
||||||
|
for ( ; k_tile_count > 0; --k_tile_count)
|
||||||
|
{
|
||||||
|
// Allocate QK acc
|
||||||
|
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
|
||||||
|
|
||||||
|
pipeline.consumer_wait(smem_pipe_read);
|
||||||
|
|
||||||
|
// MMA QK
|
||||||
|
warpgroup_fence_operand(acc_qk);
|
||||||
|
warpgroup_arrive();
|
||||||
|
|
||||||
|
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
|
||||||
|
warpgroup_commit_batch();
|
||||||
|
|
||||||
|
++smem_pipe_read;
|
||||||
|
|
||||||
|
if (warp_idx == 0) {
|
||||||
|
load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the pipeline MMAs to drain
|
||||||
|
warpgroup_wait<0>();
|
||||||
|
warpgroup_fence_operand(acc_qk);
|
||||||
|
warpgroup_fence_operand(acc_pv);
|
||||||
|
|
||||||
|
softmax.template step_interleave_begin<false>(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size);
|
||||||
|
|
||||||
|
pipeline.consumer_release(smem_pipe_release);
|
||||||
|
|
||||||
|
++smem_pipe_release;
|
||||||
|
|
||||||
|
pipeline.consumer_wait(smem_pipe_read);
|
||||||
|
|
||||||
|
// MMA PV
|
||||||
|
auto layout_qk_input = convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{}));
|
||||||
|
|
||||||
|
Tensor acc_qk_input = make_tensor(acc_qk.data(), layout_qk_input);
|
||||||
|
|
||||||
|
static_assert(decltype(size<1>(layout_qk_input) == _1{})::value);
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size<2>(tOrV); i++) {
|
||||||
|
Tensor acc_qk_element = make_fragment_like<Element>(layout_qk_input(_, _0{}, _0{}));
|
||||||
|
Tensor acc_qk_element_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_element);
|
||||||
|
Tensor acc_qk_input_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_input(_, _0{}, i));
|
||||||
|
softmax.step_interleave_step(acc_qk_input_mk, softmax_state);
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int j = 0; j < size(acc_qk_element_mk); j++) {
|
||||||
|
acc_qk_element_mk(j) = static_cast<Element>(acc_qk_input_mk(j));
|
||||||
|
}
|
||||||
|
warpgroup_arrive();
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int j = 0; j < size<1>(tOrV); j++) {
|
||||||
|
cute::gemm(tiled_mma_pv, acc_qk_element, tOrV(_,j,i,smem_pipe_read.index()), acc_pv(_,_0{},j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
warpgroup_commit_batch();
|
||||||
|
|
||||||
|
// Wait for the pipeline MMAs to drain
|
||||||
|
pipeline.consumer_release(smem_pipe_release);
|
||||||
|
++smem_pipe_release;
|
||||||
|
|
||||||
|
++smem_pipe_read;
|
||||||
|
|
||||||
|
if (warp_idx == 0) {
|
||||||
|
load_v.template step<true>(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
|
||||||
|
}
|
||||||
|
|
||||||
|
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
|
||||||
|
}
|
||||||
|
|
||||||
|
k_tile_count += Fusion{}.get_masked_trip_count(blk_coord, TileShape{}, problem_size);
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_NO_UNROLL
|
||||||
|
for ( ; k_tile_count > 0; --k_tile_count)
|
||||||
|
{
|
||||||
|
// Allocate QK acc
|
||||||
|
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
|
||||||
|
|
||||||
|
pipeline.consumer_wait(smem_pipe_read);
|
||||||
|
|
||||||
|
// MMA QK
|
||||||
|
warpgroup_fence_operand(acc_qk);
|
||||||
|
warpgroup_arrive();
|
||||||
|
|
||||||
|
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
|
||||||
|
warpgroup_commit_batch();
|
||||||
|
|
||||||
|
++smem_pipe_read;
|
||||||
|
|
||||||
|
if (warp_idx == 0) {
|
||||||
|
load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the pipeline MMAs to drain
|
||||||
|
warpgroup_wait<0>();
|
||||||
|
warpgroup_fence_operand(acc_qk);
|
||||||
|
warpgroup_fence_operand(acc_pv);
|
||||||
|
|
||||||
|
softmax.step_interleave_begin(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size);
|
||||||
|
|
||||||
|
pipeline.consumer_release(smem_pipe_release);
|
||||||
|
|
||||||
|
++smem_pipe_release;
|
||||||
|
|
||||||
|
pipeline.consumer_wait(smem_pipe_read);
|
||||||
|
|
||||||
|
// MMA PV
|
||||||
|
auto layout_qk_input = convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{}));
|
||||||
|
|
||||||
|
Tensor acc_qk_input = make_tensor(acc_qk.data(), layout_qk_input);
|
||||||
|
|
||||||
|
static_assert(decltype(size<1>(layout_qk_input) == _1{})::value);
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size<2>(tOrV); i++) {
|
||||||
|
Tensor acc_qk_element = make_fragment_like<Element>(layout_qk_input(_, _0{}, _0{}));
|
||||||
|
Tensor acc_qk_element_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_element);
|
||||||
|
Tensor acc_qk_input_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_input(_, _0{}, i));
|
||||||
|
softmax.step_interleave_step(acc_qk_input_mk, softmax_state);
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int j = 0; j < size(acc_qk_element_mk); j++) {
|
||||||
|
acc_qk_element_mk(j) = static_cast<Element>(acc_qk_input_mk(j));
|
||||||
|
}
|
||||||
|
warpgroup_arrive();
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int j = 0; j < size<1>(tOrV); j++) {
|
||||||
|
cute::gemm(tiled_mma_pv, acc_qk_element, tOrV(_,j,i,smem_pipe_read.index()), acc_pv(_,_0{},j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
warpgroup_commit_batch();
|
||||||
|
|
||||||
|
// Wait for the pipeline MMAs to drain
|
||||||
|
pipeline.consumer_release(smem_pipe_release);
|
||||||
|
++smem_pipe_release;
|
||||||
|
|
||||||
|
++smem_pipe_read;
|
||||||
|
|
||||||
|
if (warp_idx == 0) {
|
||||||
|
load_v.template step<true>(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b);
|
||||||
|
}
|
||||||
|
|
||||||
|
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the pipeline MMAs to drain
|
||||||
|
warpgroup_wait<0>();
|
||||||
|
warpgroup_fence_operand(acc_pv);
|
||||||
|
|
||||||
|
Tensor lse = softmax.tail(softmax_state, acc_pv, tiled_mma_pv);
|
||||||
|
|
||||||
|
return make_tuple(acc_pv, lse);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace cutlass::fmha::collective
|
||||||
|
|
||||||
@@ -0,0 +1,560 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||||
|
|
||||||
|
#include "../collective/fmha_common.hpp"
|
||||||
|
#include "../collective/fmha_collective_load.hpp"
|
||||||
|
#include "../collective/fmha_collective_softmax.hpp"
|
||||||
|
#include "../kernel/fmha_options.hpp"
|
||||||
|
|
||||||
|
namespace cutlass::fmha::collective {
|
||||||
|
|
||||||
|
using namespace cute;
|
||||||
|
using cutlass::fmha::kernel::Tag;
|
||||||
|
using cutlass::fmha::kernel::find_option_t;
|
||||||
|
|
||||||
|
template<
|
||||||
|
class Element_,
|
||||||
|
class ElementAccumulatorQK_,
|
||||||
|
class ElementAccumulatorPV_,
|
||||||
|
class TileShape_, // SeqQ, SeqKV, Head
|
||||||
|
class LayoutQ_, class LayoutK_, class LayoutV_, // SeqX, Head, (Batches)
|
||||||
|
class Fusion,
|
||||||
|
class... Options
|
||||||
|
>
|
||||||
|
struct FmhaMainloopTmaWarpSpecialized {
|
||||||
|
|
||||||
|
using Element = Element_;
|
||||||
|
using ElementAccumulatorQK = ElementAccumulatorQK_;
|
||||||
|
using ElementAccumulatorPV = ElementAccumulatorPV_;
|
||||||
|
using TileShape = TileShape_;
|
||||||
|
|
||||||
|
using LayoutQ = LayoutQ_;
|
||||||
|
using LayoutK = LayoutK_;
|
||||||
|
using LayoutV = LayoutV_;
|
||||||
|
|
||||||
|
// Options
static constexpr bool kIsPersistent = find_option_t<Tag::kIsPersistent, false_type, Options...>::value;
static constexpr bool kIsMainloopLocked = find_option_t<Tag::kIsMainloopLocked, false_type, Options...>::value;

static constexpr int NumLoadWarpGroups = 1;
static constexpr int NumMmaWarpGroups = find_option_t<Tag::kNumMmaWarpGroups, Int<2>, Options...>::value;
static constexpr int StageCount = find_option_t<Tag::kStagesKV, Int<5>, Options...>::value;
static constexpr int StageCountQ = find_option_t<Tag::kStagesQ, Int<NumMmaWarpGroups>, Options...>::value;

static const int kOuterLoads = 1;
using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
using Stages = cutlass::gemm::collective::StageCount<StageCount>;
using ClusterShape = Shape<_1, _1, _1>;
static_assert(StagesQ::value >= NumMmaWarpGroups);
static_assert(Stages::value >= 2);

// 16B alignment lets us use TMA
|
||||||
|
static constexpr int Alignment = 16 / sizeof(Element);
|
||||||
|
|
||||||
|
using TileShapeQK = Shape<
|
||||||
|
decltype(tuple_element_t<0, TileShape>{} / Int<NumMmaWarpGroups>{}),
|
||||||
|
tuple_element_t<1, TileShape>,
|
||||||
|
tuple_element_t<2, TileShape>>;
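// The query tile is split along its first mode so that each of the NumMmaWarpGroups
// math warp groups presumably works on its own (SeqQ-tile / NumMmaWarpGroups)-row
// slice of Q, while K and V tiles keep the full TileShape extent.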
|
||||||
|
|
||||||
|
using TileShapePV = decltype(select<0,2,1>(TileShapeQK{}));
|
||||||
|
|
||||||
|
using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||||
|
Element, LayoutQ, Alignment,
|
||||||
|
Element, LayoutK, Alignment,
|
||||||
|
ElementAccumulatorQK,
|
||||||
|
TileShapeQK, ClusterShape, Stages,
|
||||||
|
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
|
||||||
|
|
||||||
|
using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||||
|
// the stride for A does not matter since we do not load from smem at all
|
||||||
|
Element, LayoutK, Alignment,
|
||||||
|
Element, decltype(select<1,0,2>(LayoutV{})), Alignment,
|
||||||
|
ElementAccumulatorPV,
|
||||||
|
TileShapePV, ClusterShape, Stages,
|
||||||
|
cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp;
|
||||||
|
|
||||||
|
using TiledMmaQK = typename CollectiveMmaQK::TiledMma;
|
||||||
|
using TiledMmaPV = decltype(convert_to_gmma_rs(typename CollectiveMmaPV::TiledMma{}));
|
||||||
|
|
||||||
|
using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int<StagesQ::value>{}));
|
||||||
|
using SmemLayoutK = typename CollectiveMmaQK::SmemLayoutB;
|
||||||
|
using SmemLayoutV = typename CollectiveMmaPV::SmemLayoutB;
|
||||||
|
|
||||||
|
using MainloopPipeline = cutlass::PipelineTmaAsync<Stages::value>;
|
||||||
|
using MainloopPipelineQ = cutlass::PipelineTmaAsync<StagesQ::value>;
|
||||||
|
|
||||||
|
using PipelineState = typename cutlass::PipelineState<MainloopPipeline::Stages>;
|
||||||
|
using PipelineStateQ = typename cutlass::PipelineState<MainloopPipelineQ::Stages>;
|
||||||
|
|
||||||
|
static constexpr int kInnerLoadBytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element);
|
||||||
|
static constexpr int kOuterLoadBytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element);
|
||||||
|
|
||||||
|
using TileShapeOut = TileShapePV;
|
||||||
|
using TiledMmaOut = TiledMmaPV;
|
||||||
|
using ElementOut = ElementAccumulatorPV;
|
||||||
|
|
||||||
|
struct SharedStorage {
|
||||||
|
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
|
||||||
|
union {
|
||||||
|
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
|
||||||
|
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Arguments {
|
||||||
|
const Element* ptr_Q;
|
||||||
|
LayoutQ dQ;
|
||||||
|
const Element* ptr_K;
|
||||||
|
LayoutK dK;
|
||||||
|
const Element* ptr_V;
|
||||||
|
LayoutV dV;
|
||||||
|
};
|
||||||
|
|
||||||
|
using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
|
||||||
|
using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
|
||||||
|
using TMA_V = typename CollectiveMmaPV::Params::TMA_B;
|
||||||
|
|
||||||
|
struct Params {
|
||||||
|
TMA_Q tma_load_q;
|
||||||
|
TMA_K tma_load_k;
|
||||||
|
TMA_V tma_load_v;
|
||||||
|
|
||||||
|
float scale_softmax;
|
||||||
|
float scale_softmax_log2;
|
||||||
|
float rp_dropout;
|
||||||
|
};
|
||||||
|
|
||||||
|
using LoadQ = cutlass::fmha::collective::CollectiveLoadTma<
|
||||||
|
cutlass::fmha::collective::LoadKind::kQ,
|
||||||
|
MainloopPipelineQ,
|
||||||
|
Element,
|
||||||
|
SmemLayoutQ,
|
||||||
|
TMA_Q
|
||||||
|
>;
|
||||||
|
|
||||||
|
using LoadK = cutlass::fmha::collective::CollectiveLoadTma<
|
||||||
|
cutlass::fmha::collective::LoadKind::kK,
|
||||||
|
MainloopPipeline,
|
||||||
|
Element,
|
||||||
|
SmemLayoutK,
|
||||||
|
TMA_K
|
||||||
|
>;
|
||||||
|
|
||||||
|
using LoadV = cutlass::fmha::collective::CollectiveLoadTma<
|
||||||
|
cutlass::fmha::collective::LoadKind::kV,
|
||||||
|
MainloopPipeline,
|
||||||
|
Element,
|
||||||
|
SmemLayoutV,
|
||||||
|
TMA_V
|
||||||
|
>;
|
||||||
|
|
||||||
|
static_assert(size(typename CollectiveMmaQK::TiledMma{}) == size(typename CollectiveMmaPV::TiledMma{}));
|
||||||
|
|
||||||
|
template<class ProblemShape>
|
||||||
|
static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
|
||||||
|
return true
|
||||||
|
&& (get<4>(problem_size) <= get<2>(TileShape{}))
|
||||||
|
&& ((get<4>(problem_size) % Alignment) == 0)
|
||||||
|
&& ((get<2>(problem_size) % Alignment) == 0)
|
||||||
|
;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class ProblemShape>
|
||||||
|
static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) {
|
||||||
|
|
||||||
|
auto problem_shape_qk = make_shape(get<2>(problem_size), get<3>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size)));
|
||||||
|
auto params_qk = CollectiveMmaQK::to_underlying_arguments(problem_shape_qk,
|
||||||
|
typename CollectiveMmaQK::Arguments {
|
||||||
|
args.ptr_Q, args.dQ,
|
||||||
|
args.ptr_K, args.dK,
|
||||||
|
}, /*workspace=*/ nullptr);
|
||||||
|
|
||||||
|
auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk);
|
||||||
|
auto params_pv = CollectiveMmaPV::to_underlying_arguments(problem_shape_pv,
|
||||||
|
typename CollectiveMmaPV::Arguments {
|
||||||
|
args.ptr_K, args.dK, // never used, dummy
|
||||||
|
args.ptr_V, select<1,0,2>(args.dV),
|
||||||
|
}, /*workspace=*/ nullptr);
|
||||||
|
|
||||||
|
return Params{
|
||||||
|
params_qk.tma_load_a,
|
||||||
|
params_qk.tma_load_b,
|
||||||
|
params_pv.tma_load_b,
|
||||||
|
1.0f / (float) std::sqrt(get<4>(problem_size)),
|
||||||
|
(float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size))),
|
||||||
|
1.0f
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
static void prefetch_tma_descriptors(Params const& params) {
|
||||||
|
cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
|
||||||
|
cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
|
||||||
|
cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
|
||||||
|
}
|
||||||
|
|
||||||
|
template<bool kLoadQ, class BlkCoord, class ProblemShape, class LoadWarpBarrier>
|
||||||
|
CUTLASS_DEVICE void
|
||||||
|
load_kv_maybe_q(
|
||||||
|
int block_rank_in_cluster,
|
||||||
|
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
|
||||||
|
MainloopPipeline& pipeline, PipelineState& smem_pipe_write,
|
||||||
|
MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_write_q,
|
||||||
|
SharedStorage& storage,
|
||||||
|
LoadWarpBarrier& load_warp_barrier, bool do_barrier)
|
||||||
|
{
|
||||||
|
int fusion_tile_count = Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size);
|
||||||
|
|
||||||
|
int lane_predicate = cute::elect_one_sync();
|
||||||
|
|
||||||
|
uint16_t mcast_mask_b = 0;
|
||||||
|
|
||||||
|
if (lane_predicate == 1) {
|
||||||
|
if constexpr (cute::is_same_v<typename CollectiveMmaQK::GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
|
||||||
|
auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
|
||||||
|
for (int m = 0; m < size<0>(block_layout); ++m) {
|
||||||
|
mcast_mask_b |= (uint16_t(1) << block_layout(m,_0{},Int<0>{}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto q_tile_iter = cute::make_coord_iterator(Int<NumMmaWarpGroups>{});
|
||||||
|
[[maybe_unused]] int q_tile_count = NumMmaWarpGroups;
|
||||||
|
|
||||||
|
auto k_tile_iter = cute::make_coord_iterator(fusion_tile_count);
|
||||||
|
int k_tile_count = 2 * fusion_tile_count;
|
||||||
|
|
||||||
|
LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q};
|
||||||
|
auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, NumMmaWarpGroups);
|
||||||
|
|
||||||
|
LoadK load_k{params.tma_load_k, pipeline, storage.smem_k};
|
||||||
|
auto load_state_k = load_k.init_state(block_rank_in_cluster, problem_size, TileShapeQK{}, blk_coord, fusion_tile_count);
|
||||||
|
|
||||||
|
LoadV load_v{params.tma_load_v, pipeline, storage.smem_v};
|
||||||
|
auto load_state_v = load_v.init_state(block_rank_in_cluster, problem_size, TileShapePV{}, blk_coord, fusion_tile_count);
|
||||||
|
|
||||||
|
if constexpr (kLoadQ) {
|
||||||
|
load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count);
|
||||||
|
}
|
||||||
|
|
||||||
|
load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b);
|
||||||
|
|
||||||
|
if constexpr (kLoadQ) {
|
||||||
|
load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count);
|
||||||
|
}
|
||||||
|
|
||||||
|
if constexpr (! kLoadQ) {
|
||||||
|
if (do_barrier) {
|
||||||
|
load_warp_barrier.arrive();
|
||||||
|
load_warp_barrier.wait(/*phase=*/ 0);
|
||||||
|
do_barrier = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
load_v.template step<true>(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b);
|
||||||
|
|
||||||
|
if constexpr (kLoadQ) {
|
||||||
|
while (q_tile_count > 0) {
|
||||||
|
load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_NO_UNROLL
|
||||||
|
while (k_tile_count > 0) {
|
||||||
|
load_k.template step<false>(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b);
|
||||||
|
load_v.template step<true>(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class BlkCoord, class ProblemShape, class LoadWarpBarrier>
|
||||||
|
CUTLASS_DEVICE void
|
||||||
|
load_maybe_q(
|
||||||
|
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
|
||||||
|
MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_write_q,
|
||||||
|
SharedStorage& storage,
|
||||||
|
LoadWarpBarrier& load_warp_barrier, bool do_barrier)
|
||||||
|
{
|
||||||
|
int lane_predicate = cute::elect_one_sync();
|
||||||
|
|
||||||
|
LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q};
|
||||||
|
auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, NumMmaWarpGroups);
|
||||||
|
|
||||||
|
auto q_tile_iter = cute::make_coord_iterator(Int<NumMmaWarpGroups>{});
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int q_tile_count = 0; q_tile_count < NumMmaWarpGroups; q_tile_count++) {
|
||||||
|
int count = 1;
|
||||||
|
load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, count);
|
||||||
|
if (q_tile_count == 0 && do_barrier) {
|
||||||
|
load_warp_barrier.arrive();
|
||||||
|
load_warp_barrier.wait(/*phase=*/ 0);
|
||||||
|
do_barrier = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class BlkCoord, class ProblemShape, class MainloopPipelineReducer, class PipelineStateReducer>
|
||||||
|
CUTLASS_DEVICE void
|
||||||
|
reduce(
|
||||||
|
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size,
|
||||||
|
MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_write_reducer,
|
||||||
|
SharedStorage& storage)
|
||||||
|
{ /* no-op */ }
|
||||||
|
|
||||||
|
template<class BlkCoord, class ProblemShape, class MainloopPipelineReducer, class PipelineStateReducer, class MathWgOrderBarrier>
|
||||||
|
CUTLASS_DEVICE auto
|
||||||
|
compute(
|
||||||
|
BlkCoord const& blk_coord, BlkCoord const& wg_coord,
|
||||||
|
Params const& params, ProblemShape const& problem_size,
|
||||||
|
MainloopPipeline& pipeline, PipelineState& smem_pipe_read,
|
||||||
|
MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_read_q,
|
||||||
|
MainloopPipelineReducer&, PipelineStateReducer&,
|
||||||
|
SharedStorage& storage,
|
||||||
|
MathWgOrderBarrier& math_wg_order_barrier)
|
||||||
|
{
|
||||||
|
int thread_idx = int(threadIdx.x);
|
||||||
|
|
||||||
|
PipelineState smem_pipe_release = smem_pipe_read;
|
||||||
|
PipelineStateQ smem_pipe_release_q = smem_pipe_read_q;
|
||||||
|
|
||||||
|
TiledMmaQK tiled_mma_qk;
|
||||||
|
auto thr_mma_qk = tiled_mma_qk.get_thread_slice(thread_idx);
|
||||||
|
|
||||||
|
// Mainloop setup QK
|
||||||
|
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
|
||||||
|
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
|
||||||
|
|
||||||
|
Tensor tSsQ = thr_mma_qk.partition_A(sQ); // (MMA,MMA_M,MMA_K,PIPE)
|
||||||
|
Tensor tSsK = thr_mma_qk.partition_B(sK); // (MMA,MMA_N,MMA_K,PIPE)
|
||||||
|
Tensor tSrQ = thr_mma_qk.make_fragment_A(tSsQ); // (MMA,MMA_M,MMA_K,PIPE)
|
||||||
|
Tensor tSrK = thr_mma_qk.make_fragment_B(tSsK); // (MMA,MMA_N,MMA_K,PIPE)
|
||||||
|
|
||||||
|
// Prepare: MMA PV
|
||||||
|
TiledMmaPV tiled_mma_pv;
|
||||||
|
auto thr_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx);
|
||||||
|
|
||||||
|
// Mainloop setup PV
|
||||||
|
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
|
||||||
|
|
||||||
|
Tensor tOsV = thr_mma_pv.partition_B(sV); // (MMA,MMA_N,MMA_K,PIPE)
|
||||||
|
Tensor tOrV = thr_mma_pv.make_fragment_B(tOsV); // (MMA,MMA_N,MMA_K,PIPE)
|
||||||
|
|
||||||
|
int k_tile_count = Fusion{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_size);
|
||||||
|
|
||||||
|
pipeline_q.consumer_wait(smem_pipe_read_q);
|
||||||
|
|
||||||
|
// mapping into QK accumulator
|
||||||
|
Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{}));
|
||||||
|
Tensor tPcP = thr_mma_qk.partition_C(cP);
|
||||||
|
int m_block = get<0>(wg_coord);
|
||||||
|
tPcP.data() = tPcP.data() + E<0>{} * m_block * get<0>(TileShapeQK{});
|
||||||
|
|
||||||
|
// Allocate PV acc
|
||||||
|
Tensor acc_pv = partition_fragment_C(tiled_mma_pv, take<0, 2>(TileShapePV{}));
|
||||||
|
|
||||||
|
cutlass::fmha::collective::CollectiveSoftmax<ElementAccumulatorQK, Fusion, decltype(params)> softmax{params};
|
||||||
|
auto softmax_state = softmax.init(acc_pv, tiled_mma_pv);
|
||||||
|
|
||||||
|
if (true)
|
||||||
|
{
|
||||||
|
--k_tile_count;
|
||||||
|
// Allocate QK acc
|
||||||
|
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
|
||||||
|
|
||||||
|
pipeline.consumer_wait(smem_pipe_read);
|
||||||
|
math_wg_order_barrier.wait();
|
||||||
|
|
||||||
|
// MMA QK
|
||||||
|
warpgroup_fence_operand(acc_qk);
|
||||||
|
warpgroup_arrive();
|
||||||
|
|
||||||
|
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
|
||||||
|
warpgroup_commit_batch();
|
||||||
|
math_wg_order_barrier.arrive();
|
||||||
|
|
||||||
|
++smem_pipe_read;
|
||||||
|
|
||||||
|
// Wait for the pipeline MMAs to drain
|
||||||
|
warpgroup_wait<0>();
|
||||||
|
warpgroup_fence_operand(acc_qk);
|
||||||
|
|
||||||
|
softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, problem_size);
|
||||||
|
|
||||||
|
Tensor acc_qk_fixed = make_acc_into_op<Element>(acc_qk, typename TiledMmaPV::LayoutA_TV{});
|
||||||
|
|
||||||
|
pipeline.consumer_wait(smem_pipe_read);
|
||||||
|
|
||||||
|
// MMA PV
|
||||||
|
warpgroup_fence_operand(acc_pv);
|
||||||
|
warpgroup_fence_operand(acc_qk_fixed);
|
||||||
|
warpgroup_arrive();
|
||||||
|
|
||||||
|
gemm_zero_acc(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv);
|
||||||
|
warpgroup_commit_batch();
|
||||||
|
|
||||||
|
pipeline.consumer_release(smem_pipe_release);
|
||||||
|
++smem_pipe_release;
|
||||||
|
|
||||||
|
// Advance consumer pipeline
|
||||||
|
++smem_pipe_read;
|
||||||
|
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_NO_UNROLL
|
||||||
|
while (k_tile_count > 0)
|
||||||
|
{
|
||||||
|
--k_tile_count;
|
||||||
|
|
||||||
|
// Allocate QK acc
|
||||||
|
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
|
||||||
|
|
||||||
|
pipeline.consumer_wait(smem_pipe_read);
|
||||||
|
|
||||||
|
// MMA QK
|
||||||
|
warpgroup_fence_operand(acc_qk);
|
||||||
|
warpgroup_arrive();
|
||||||
|
|
||||||
|
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
|
||||||
|
warpgroup_commit_batch();
|
||||||
|
|
||||||
|
++smem_pipe_read;
|
||||||
|
auto tok = pipeline.consumer_try_wait(smem_pipe_read);
|
||||||
|
|
||||||
|
// Wait for the pipeline MMAs to drain
|
||||||
|
warpgroup_wait<0>();
|
||||||
|
warpgroup_fence_operand(acc_qk);
|
||||||
|
warpgroup_fence_operand(acc_pv);
|
||||||
|
|
||||||
|
if constexpr (kIsMainloopLocked) math_wg_order_barrier.wait();
|
||||||
|
softmax.template step<false>(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size);
|
||||||
|
if constexpr (kIsMainloopLocked) math_wg_order_barrier.arrive();
|
||||||
|
|
||||||
|
Tensor acc_qk_fixed = make_acc_into_op<Element>(acc_qk, typename TiledMmaPV::LayoutA_TV{});
|
||||||
|
|
||||||
|
pipeline.consumer_wait(smem_pipe_read, tok);
|
||||||
|
|
||||||
|
// MMA PV
|
||||||
|
warpgroup_fence_operand(acc_pv);
|
||||||
|
warpgroup_fence_operand(acc_qk_fixed);
|
||||||
|
warpgroup_arrive();
|
||||||
|
|
||||||
|
cute::gemm(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv);
|
||||||
|
warpgroup_commit_batch();
|
||||||
|
|
||||||
|
pipeline.consumer_release(smem_pipe_release);
|
||||||
|
++smem_pipe_release;
|
||||||
|
|
||||||
|
pipeline.consumer_release(smem_pipe_release);
|
||||||
|
++smem_pipe_release;
|
||||||
|
|
||||||
|
++smem_pipe_read;
|
||||||
|
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
|
||||||
|
}
|
||||||
|
|
||||||
|
k_tile_count += Fusion{}.get_masked_trip_count(blk_coord, TileShape{}, problem_size);
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_NO_UNROLL
|
||||||
|
while (k_tile_count > 0)
|
||||||
|
{
|
||||||
|
--k_tile_count;
|
||||||
|
|
||||||
|
// Allocate QK acc
|
||||||
|
Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{}));
|
||||||
|
|
||||||
|
pipeline.consumer_wait(smem_pipe_read);
|
||||||
|
|
||||||
|
// MMA QK
|
||||||
|
warpgroup_fence_operand(acc_qk);
|
||||||
|
warpgroup_arrive();
|
||||||
|
|
||||||
|
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk);
|
||||||
|
warpgroup_commit_batch();
|
||||||
|
|
||||||
|
++smem_pipe_read;
|
||||||
|
auto tok = pipeline.consumer_try_wait(smem_pipe_read);
|
||||||
|
|
||||||
|
// Wait for the pipeline MMAs to drain
|
||||||
|
warpgroup_wait<0>();
|
||||||
|
warpgroup_fence_operand(acc_qk);
|
||||||
|
warpgroup_fence_operand(acc_pv);
|
||||||
|
|
||||||
|
//if constexpr (kIsPersistent)
|
||||||
|
// if (k_tile_count == 0) pipeline_q.consumer_release(smem_pipe_release_q);
|
||||||
|
|
||||||
|
if constexpr (kIsMainloopLocked) math_wg_order_barrier.wait();
|
||||||
|
softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size);
|
||||||
|
if constexpr (kIsMainloopLocked) math_wg_order_barrier.arrive();
|
||||||
|
|
||||||
|
Tensor acc_qk_fixed = make_acc_into_op<Element>(acc_qk, typename TiledMmaPV::LayoutA_TV{});
|
||||||
|
|
||||||
|
pipeline.consumer_wait(smem_pipe_read, tok);
|
||||||
|
|
||||||
|
// MMA PV
|
||||||
|
warpgroup_fence_operand(acc_pv);
|
||||||
|
warpgroup_fence_operand(acc_qk_fixed);
|
||||||
|
warpgroup_arrive();
|
||||||
|
|
||||||
|
cute::gemm(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv);
|
||||||
|
warpgroup_commit_batch();
|
||||||
|
|
||||||
|
pipeline.consumer_release(smem_pipe_release);
|
||||||
|
++smem_pipe_release;
|
||||||
|
|
||||||
|
pipeline.consumer_release(smem_pipe_release);
|
||||||
|
++smem_pipe_release;
|
||||||
|
|
||||||
|
++smem_pipe_read;
|
||||||
|
tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (kIsPersistent) pipeline_q.consumer_release(smem_pipe_release_q);
|
||||||
|
|
||||||
|
// Wait for the pipeline MMAs to drain
|
||||||
|
warpgroup_wait<0>();
|
||||||
|
warpgroup_fence_operand(acc_pv);
|
||||||
|
|
||||||
|
if (kIsPersistent) pipeline.consumer_release(smem_pipe_release);
|
||||||
|
++smem_pipe_release;
|
||||||
|
|
||||||
|
Tensor lse = softmax.tail(softmax_state, acc_pv, tiled_mma_pv);
|
||||||
|
|
||||||
|
return make_tuple(acc_pv, lse);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace cutlass::fmha::collective
|
||||||
245
examples/88_hopper_fmha/collective/fmha_common.hpp
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/

#pragma once

#include "cutlass/kernel_hardware_info.h"
#include "cute/tensor.hpp"

namespace cutlass::fmha::collective {

using namespace cute;

template<typename Atom, typename TA, typename TB, typename TC>
CUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
constexpr int rA = decltype(rank(tA))::value;
constexpr int rB = decltype(rank(tB))::value;
constexpr int rC = decltype(rank(tC))::value;
if constexpr (rA == 2 && rB == 2 && rC == 1) {
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<1>(tA); k_block++) {
cute::gemm(atom, tA(_,k_block), tB(_,k_block), tC);
atom.accumulate_ = GMMA::ScaleOut::One;
}
} else {
static_assert(rA == 3 && rB == 3 && rC == 3);
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tA); k_block++) {
cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC);
atom.accumulate_ = GMMA::ScaleOut::One;
}
}
}

template<typename Atom, typename TA, typename TB, typename TC>
CUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
atom.accumulate_ = GMMA::ScaleOut::Zero;
gemm_reset_zero_acc(atom, tA, tB, tC);
}
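// gemm_zero_acc zeroes the accumulator on the first k-block (ScaleOut::Zero) and then
// defers to gemm_reset_zero_acc, which switches the atom to ScaleOut::One after each
// MMA so that the remaining k-blocks accumulate on top of it.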
template<typename T, typename Fn>
CUTE_DEVICE constexpr typename T::value_type reduce(T const& t, Fn fn) {
if constexpr (decltype(size(t) % _2{} == _0{})::value) {
auto partial = make_tensor<typename T::value_type>(size(t) / _2{});
CUTE_UNROLL
for (int i = 0; i < size(partial); i++) {
partial(i) = fn(t(i), t(i + size(partial)));
}
return reduce(partial, fn);
} else {
auto result = t(_0{});
CUTE_UNROLL
for (int i = 1; i < size(t); i++) {
result = fn(result, t(i));
}
return result;
}
}
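// The reduction halves the fragment pairwise while its size stays even (a log2-depth
// tree) and falls back to a serial fold otherwise; e.g. 8 values take three pairwise
// rounds, while 6 values take one pairwise round followed by a 3-element serial fold.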

struct fmha_max {
CUTE_DEVICE float operator()(float a, float b) { return ::max(a, b); }
};
|
||||||
|
|
||||||
|
template<typename Threshold, typename Source, typename Reference>
|
||||||
|
inline auto __device__ constexpr layout_separate(Threshold const& thr,
|
||||||
|
Source const& src, Reference const& ref) {
|
||||||
|
auto lt = filter(transform_layout(src, ref, [&](auto const& s, auto const& r) {
|
||||||
|
if constexpr(decltype(r < thr)::value) {
|
||||||
|
return s;
|
||||||
|
} else {
|
||||||
|
return make_layout(_1{}, _0{});
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
auto ge = filter(transform_layout(src, ref, [&](auto const& s, auto const& r) {
|
||||||
|
if constexpr(decltype(r >= thr)::value) {
|
||||||
|
return s;
|
||||||
|
} else {
|
||||||
|
return make_layout(_1{}, _0{});
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
return make_layout(lt, ge);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename TiledMma, typename Acc>
|
||||||
|
inline auto __device__ constexpr layout_acc_mn(TiledMma const& tiled_mma, Acc const& acc) {
|
||||||
|
auto separated = layout_separate(get<0>(typename TiledMma::Shape_MNK{}),
|
||||||
|
get<0>(acc), stride<1>(typename TiledMma::LayoutC_TV{}));
|
||||||
|
auto V_M = get<0>(separated);
|
||||||
|
auto V_N = get<1>(separated);
|
||||||
|
return make_layout(make_layout(V_M, get<1>(acc)), make_layout(V_N, get<2>(acc)));
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename TiledMma, typename Acc>
|
||||||
|
inline auto __device__ constexpr layout_op_mk_v(TiledMma const& tiled_mma, Acc const& acc) {
|
||||||
|
return layout_separate(get<0>(typename TiledMma::Shape_MNK{}),
|
||||||
|
get<0>(acc), stride<1>(typename TiledMma::LayoutA_TV{}));
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename TiledMma, typename Acc>
|
||||||
|
inline auto __device__ constexpr tensor_op_mk_v(TiledMma const& tiled_mma, Acc&& acc) {
|
||||||
|
return make_tensor(acc.data(), layout_op_mk_v(tiled_mma, acc.layout()));
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename TiledMma>
|
||||||
|
inline auto __device__ constexpr reduction_target_n(TiledMma const& tiled_mma) {
|
||||||
|
auto separated = layout_separate(get<0>(typename TiledMma::Shape_MNK{}),
|
||||||
|
make_layout(shape<0>(typename TiledMma::LayoutC_TV{})),
|
||||||
|
stride<0>(typename TiledMma::LayoutC_TV{}));
|
||||||
|
return get<1>(separated);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template<template<cute::GMMA::Major, cute::GMMA::Major, cute::GMMA::ScaleIn, cute::GMMA::ScaleIn> class Primitive, cute::GMMA::Major tA, cute::GMMA::Major tB, cute::GMMA::ScaleIn sA, cute::GMMA::ScaleIn sB>
|
||||||
|
inline auto __device__ constexpr convert_to_gmma_rs(cute::MMA_Atom<Primitive<tA, tB, sA, sB>> const& tiled_mma) {
|
||||||
|
using Atom = cute::MMA_Atom<Primitive<tA, tB, sA, sB>>;
|
||||||
|
using ElementA = typename Atom::ValTypeA;
|
||||||
|
using ElementB = typename Atom::ValTypeB;
|
||||||
|
using ElementC = typename Atom::ValTypeC;
|
||||||
|
using Shape_MNK = typename Atom::Shape_MNK;
|
||||||
|
using RS = decltype(cute::GMMA::rs_op_selector<ElementA, ElementB, ElementC, Shape_MNK, tA, tB, sA, sB>());
|
||||||
|
return cute::MMA_Atom<RS>{};
|
||||||
|
}
|
||||||
|
|
||||||
|
template<template<cute::GMMA::ScaleIn, cute::GMMA::ScaleIn> class Primitive, cute::GMMA::ScaleIn sA, cute::GMMA::ScaleIn sB>
|
||||||
|
inline auto __device__ constexpr convert_to_gmma_rs(cute::MMA_Atom<Primitive<sA, sB>> const& tiled_mma) {
|
||||||
|
using Atom = cute::MMA_Atom<Primitive<sA, sB>>;
|
||||||
|
using ElementA = typename Atom::ValTypeA;
|
||||||
|
using ElementB = typename Atom::ValTypeB;
|
||||||
|
using ElementC = typename Atom::ValTypeC;
|
||||||
|
using Shape_MNK = typename Atom::Shape_MNK;
|
||||||
|
constexpr auto tA = cute::GMMA::Major::K;
|
||||||
|
constexpr auto tB = cute::GMMA::Major::K;
|
||||||
|
using RS = decltype(cute::GMMA::rs_op_selector<ElementA, ElementB, ElementC, Shape_MNK, tA, tB, sA, sB>());
|
||||||
|
return cute::MMA_Atom<RS>{};
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class Atom, class... Args>
|
||||||
|
CUTE_DEVICE auto constexpr convert_to_gmma_rs(cute::TiledMMA<Atom, Args...> const& tiled_mma) {
|
||||||
|
return cute::TiledMMA<decltype(convert_to_gmma_rs(Atom{})), Args...>{};
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename CLayout, typename AValueShape>
CUTE_DEVICE auto constexpr convert_c_layout_to_a_layout(CLayout const& c, AValueShape const& a) {
return make_layout(
make_shape(a, shape<1>(c), make_shape(shape<2>(c), size<0>(c) / size(a))),
make_stride(stride<0>(c), stride<1>(c), make_stride(stride<2>(c), size<2>(a) * stride<0,2>(c))));
}
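// This carves the per-MMA value mode of a C-accumulator fragment into the A-operand
// value shape (plus any leftover values, which become extra k-blocks), so the
// softmax-ed QK accumulator can be fed from registers as the A operand of the PV GEMM
// in the mainloops above.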

template<class Layout, class Stages = _1>
CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) {
return composition(layout, make_tuple(_, _, make_layout(stages)));
}
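// Replaces the stage mode of a collective-builder shared-memory layout with a
// caller-chosen stage count; the mainloops above use it to give Q its own StagesQ
// pipeline depth, independent of the KV StageCount.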
|
||||||
|
|
||||||
|
template<class Element, class Accumulator, class OperandLayout_TV>
|
||||||
|
CUTE_DEVICE auto make_acc_into_op(Accumulator const& acc, OperandLayout_TV const& operand_layout_tv) {
|
||||||
|
Tensor operand = make_fragment_like<Element>(convert_c_layout_to_a_layout(acc.layout(), shape<1>(operand_layout_tv)));
|
||||||
|
Tensor operand_as_acc = make_tensor(operand.data(), acc.layout());
|
||||||
|
|
||||||
|
cute::copy(acc, operand_as_acc);
|
||||||
|
|
||||||
|
if constexpr (sizeof(Element) == 1) {
|
||||||
|
|
||||||
|
// 00 11 22 33 00 11 22 33 acc layout
|
||||||
|
// 00 00 11 11 22 22 33 33 operand layout
|
||||||
|
// BB AA AA BB AA BB BB AA conflict-free exchange pattern
|
||||||
|
// 16-bit exchange; so process two at a time potentially
|
||||||
|
int tid = threadIdx.x % 4;
|
||||||
|
auto values_u32 = recast<uint32_t>(operand);
|
||||||
|
|
||||||
|
CUTE_UNROLL
|
||||||
|
for (int n = 0; n < size<1>(values_u32); n++) {
|
||||||
|
CUTE_UNROLL
|
||||||
|
for (int k = 0; k < size<2>(values_u32); k++) {
|
||||||
|
CUTE_UNROLL
|
||||||
|
for (int ii = 0; ii < 8; ii += 4) {
|
||||||
|
|
||||||
|
uint32_t values_tmp_0 = values_u32(ii / 2 + 0, n, k);
|
||||||
|
uint32_t values_tmp_1 = values_u32(ii / 2 + 1, n, k);
|
||||||
|
|
||||||
|
// step A:
|
||||||
|
// t 1 v 0 -> t 0 v 1
|
||||||
|
// t 2 v 0 -> t 1 v 0
|
||||||
|
// t 0 v 1 -> t 2 v 0
|
||||||
|
// t 3 v 1 -> t 3 v 1
|
||||||
|
|
||||||
|
int v_to_send = tid == 1 || tid == 2 ? 0 : 1;
|
||||||
|
int v_to_recv = v_to_send;
|
||||||
|
int t_to_recv_from = (0x3021 >> (tid * 4)) & 0xF;
|
||||||
|
|
||||||
|
uint32_t values_tmp_a = v_to_send == 0 ? values_tmp_0 : values_tmp_1;
|
||||||
|
|
||||||
|
values_tmp_a = __shfl_sync(0xFFFFFFFF, values_tmp_a, t_to_recv_from, 4);
|
||||||
|
|
||||||
|
// step B:
|
||||||
|
// t 0 v 0 -> t 0 v 0
|
||||||
|
// t 3 v 0 -> t 1 v 1
|
||||||
|
// t 1 v 1 -> t 2 v 1
|
||||||
|
// t 2 v 1 -> t 3 v 0
|
||||||
|
|
||||||
|
v_to_send = 1 - v_to_send;
|
||||||
|
v_to_recv = 1 - v_to_recv;
|
||||||
|
t_to_recv_from = (0x2130 >> (tid * 4)) & 0xF;
|
||||||
|
|
||||||
|
uint32_t values_tmp_b = v_to_send == 0 ? values_tmp_0 : values_tmp_1;
|
||||||
|
|
||||||
|
values_tmp_b = __shfl_sync(0xFFFFFFFF, values_tmp_b, t_to_recv_from, 4);
|
||||||
|
|
||||||
|
values_u32(ii / 2 + 0, n, k) = __byte_perm(values_tmp_a, values_tmp_b, v_to_send == 0 ? 0x1054 : 0x5410);
|
||||||
|
values_u32(ii / 2 + 1, n, k) = __byte_perm(values_tmp_a, values_tmp_b, v_to_send == 0 ? 0x3276 : 0x7632);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return operand;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cutlass::fmha::collective
|
||||||
156
examples/88_hopper_fmha/collective/fmha_epilogue.hpp
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"

#include "../collective/fmha_common.hpp"

namespace cutlass::fmha::collective {

template<class Element, class ElementAccumulator, class TileShape_WG>
struct FmhaFwdEpilogue {

  static constexpr int Alignment = 16 / sizeof(Element);

  using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<Element, ElementAccumulator, void>;
  using CollectiveEpilogueTMA = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape_WG, Shape<_1,_1,_1>, cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator, ElementAccumulator,
      void, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
      Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
      cutlass::epilogue::TmaWarpSpecialized,
      DefaultOperation
    >::CollectiveOp;

  struct Arguments {
    Element* ptr_O;
    cute::tuple<int, cute::_1, cute::tuple<int, int>> dO;

    ElementAccumulator* ptr_LSE;
    cute::tuple<cute::_1, cute::tuple<int, int>> dLSE;
  };

  struct Params {
    ElementAccumulator* ptr_LSE;
    cute::tuple<cute::_1, cute::tuple<int, int>> dLSE;

    typename CollectiveEpilogueTMA::Params epilogue_TMA;
  };

  using TensorStorage = typename CollectiveEpilogueTMA::TensorStorage;
  using PipelineStorage = typename CollectiveEpilogueTMA::PipelineStorage;
  using LoadPipeline = typename CollectiveEpilogueTMA::LoadPipeline;
  static constexpr int TmaTransactionBytes = CollectiveEpilogueTMA::TmaTransactionBytes;

  template<class ProblemShape>
  static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace = nullptr) {
    auto problem_size_o = make_shape(get<2>(problem_size), get<4>(problem_size), 1,
        make_shape(get<0>(problem_size), get<1>(problem_size)));
    typename CollectiveEpilogueTMA::Arguments args_tma{{}, args.ptr_O, args.dO, args.ptr_O, args.dO};
    return Params{
        args.ptr_LSE, args.dLSE,
        CollectiveEpilogueTMA::to_underlying_arguments(problem_size_o, args_tma, workspace)
    };
  }

  template<class TileShape, class BlkCoord, class ResultTuple, class TiledMma, class ProblemShape>
  CUTLASS_DEVICE void operator()(
      TileShape const& tile_shape, BlkCoord const& blk_coord,
      ResultTuple const& result, TiledMma const& tiled_mma,
      ProblemShape const& problem_size, Params const& params,
      LoadPipeline epi_load_pipeline,
      TensorStorage& epi_tensor_storage)
  {
    using X = Underscore;

    auto acc = get<0>(result);
    auto lse = get<1>(result);

    auto thr_mma = tiled_mma.get_thread_slice(threadIdx.x);

    int seqlen_q = get<2>(problem_size);
    int num_batch = get<0>(problem_size);
    int num_heads = get<1>(problem_size);
    // Epilogue for lse
    Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE),
        make_shape(seqlen_q, get<1>(tile_shape), make_shape(num_batch, num_heads)),
        make_stride(_1{}, _0{}, get<1>(params.dLSE)));
    Tensor gLSE_full = local_tile(mLSE, tile_shape, make_coord(_, _, _), Step<_1, _1, X>{});
    Tensor gLSE = gLSE_full(_, _, get<0>(blk_coord), get<1>(blk_coord), get<2>(blk_coord));
    Tensor tOgLSE = thr_mma.partition_C(gLSE);
    Tensor cO = make_identity_tensor(take<0,2>(tile_shape));
    Tensor tOcO = thr_mma.partition_C(cO);
    if (get<1>(tOcO(_0{})) == 0) {
      auto tOgLSE_mn = make_tensor(tOgLSE.data(), layout_acc_mn(tiled_mma, tOgLSE.layout()));
      auto tOcO_mn = make_tensor(tOcO.data(), layout_acc_mn(tiled_mma, tOcO.layout()));
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size<0>(tOgLSE_mn); i++) {
        if (get<0>(tOcO_mn(i)) + get<0>(blk_coord) * get<0>(tile_shape) < get<2>(problem_size)) {
          tOgLSE_mn(i, _0{}) = lse(i);
        }
      }
    }
    auto problem_size_o = make_shape(get<2>(problem_size), get<4>(problem_size), _,
        make_shape(get<0>(problem_size), get<1>(problem_size)));

    CollectiveEpilogueTMA epilogue_tma(params.epilogue_TMA, epi_tensor_storage);

    using EpiStorePipeline = typename CollectiveEpilogueTMA::StorePipeline;
    typename EpiStorePipeline::Params epi_store_pipeline_params;
    epi_store_pipeline_params.always_wait = true;
    EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);

    typename CollectiveEpilogueTMA::LoadPipelineState epi_load_pipe_consumer_state;
    PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();

    auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
      epilogue_tma.store(
        epi_load_pipeline, epi_load_pipe_consumer_state,
        epi_store_pipeline, epi_store_pipe_producer_state,
        problem_size_o, tile_shape, make_coord(get<0>(blk_coord), _0{}, _, get<2>(blk_coord)),
        acc, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup,
        epi_tensor_storage
      );

    epilogue_tma.store_tail(
      epi_load_pipeline, epi_load_pipe_consumer_state_next,
      epi_store_pipeline, epi_store_pipe_producer_state_next
    );
  }
};

} // namespace cutlass::fmha::collective
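Note on to_underlying_arguments: the epilogue reuses the GEMM collective builder, so the 5-D FMHA problem size (B, H, Q, K, D) is recast as a GEMM-style (M, N, K, L) shape with M = seqlen_q, N = head dim, K = 1 and a grouped batch mode L = (B, H). A host-only sketch of just that reshape (illustrative only; assumes it is compiled with nvcc and the CUTLASS/CuTe include path):

#include <cute/tensor.hpp>
using namespace cute;

int main() {
  auto problem_size = make_shape(4, 8, 1024, 1024, 128);   // (B, H, Q, K, D), made-up values
  // Same index selection as to_underlying_arguments(): O is an (M, N, K=1, L=(B,H)) output.
  auto problem_size_o = make_shape(get<2>(problem_size), get<4>(problem_size), 1,
                                   make_shape(get<0>(problem_size), get<1>(problem_size)));
  print(problem_size_o); print("\n");   // (1024,128,1,(4,8))
  return 0;
}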
157  examples/88_hopper_fmha/collective/fmha_epilogue_bwd.hpp  Normal file
@@ -0,0 +1,157 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"

#include "../collective/fmha_epilogue.hpp"

namespace cutlass::fmha::collective {

template<class Element, class ElementAccumulator, class TileShape_WG>
struct FmhaBwdEpilogueKV {

  static constexpr int Alignment = 16 / sizeof(Element);

  struct Arguments {
    Element* ptr_K;
    cute::tuple<int, int, int, cute::_1> dK;

    Element* ptr_V;
    cute::tuple<int, int, int, _1> dV;
  };

  //using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<Element, ElementAccumulator, void>;
  static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
  using DefaultOperation = cutlass::epilogue::fusion::Sm90EVT<
      cutlass::epilogue::fusion::Sm90Compute<cutlass::first, Element, ElementAccumulator, RoundStyle>,
      cutlass::epilogue::fusion::Sm90AccFetch
    >;
  using CollectiveEpilogueTMA = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape_WG, Shape<_1,_1,_1>, cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator, ElementAccumulator,
      void, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
      Element, cute::tuple<int, _1, cute::tuple<int, int>>, Alignment,
      cutlass::epilogue::TmaWarpSpecialized,
      DefaultOperation
    >::CollectiveOp;

  struct Params {
    typename CollectiveEpilogueTMA::Params epilogue_K;
    typename CollectiveEpilogueTMA::Params epilogue_V;
  };

  using TensorStorage = typename CollectiveEpilogueTMA::TensorStorage[2];
  using PipelineStorage = typename CollectiveEpilogueTMA::PipelineStorage;
  using LoadPipeline = typename CollectiveEpilogueTMA::LoadPipeline;
  static constexpr int TmaTransactionBytes = CollectiveEpilogueTMA::TmaTransactionBytes;

  template<class ProblemShape>
  static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace = nullptr) {
    auto dK = make_stride(get<2>(args.dK), get<3>(args.dK),
        make_stride(get<0>(args.dK), get<1>(args.dK)));
    auto dV = make_stride(get<2>(args.dV), get<3>(args.dV),
        make_stride(get<0>(args.dV), get<1>(args.dV)));

    auto problem_size_kv = make_shape(get<3>(problem_size), get<4>(problem_size), 1,
        make_shape(get<0>(problem_size), get<1>(problem_size)));
    typename CollectiveEpilogueTMA::Arguments args_k{{}, args.ptr_K, dK, args.ptr_K, dK};
    typename CollectiveEpilogueTMA::Arguments args_v{{}, args.ptr_V, dV, args.ptr_V, dV};
    return Params{
        CollectiveEpilogueTMA::to_underlying_arguments(problem_size_kv, args_k, nullptr),
        CollectiveEpilogueTMA::to_underlying_arguments(problem_size_kv, args_v, nullptr)
    };
  }

  template<class TileShape, class BlkCoord, class ResultTuple, class TiledMma, class ProblemShape>
  CUTLASS_DEVICE void operator()(
      TileShape const& tile_shape, BlkCoord const& blk_coord,
      ResultTuple const& result, TiledMma const& tiled_mma,
      ProblemShape const& problem_size, Params const& params,
      LoadPipeline epi_load_pipeline, TensorStorage& epi_tensor_storage)
  {
    auto acc_k = get<0>(result);
    auto acc_v = get<1>(result);

    auto problem_size_kv = make_shape(get<3>(problem_size), get<4>(problem_size), _,
        make_shape(get<0>(problem_size), get<1>(problem_size)));

    using EpiStorePipeline = typename CollectiveEpilogueTMA::StorePipeline;
    typename EpiStorePipeline::Params epi_store_pipeline_params;
    epi_store_pipeline_params.always_wait = true;
    EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);

    typename CollectiveEpilogueTMA::LoadPipelineState epi_load_pipe_consumer_state;
    PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();

    CollectiveEpilogueTMA epilogue_k{params.epilogue_K, epi_tensor_storage[0]};
    CollectiveEpilogueTMA epilogue_v{params.epilogue_V, epi_tensor_storage[1]};

    {
      auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
        epilogue_k.store(
          epi_load_pipeline, epi_load_pipe_consumer_state,
          epi_store_pipeline, epi_store_pipe_producer_state,
          problem_size_kv, tile_shape, make_coord(get<1>(blk_coord), _0{}, _, get<2>(blk_coord)),
          acc_k, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup,
          epi_tensor_storage[0]
        );
    }

    {
      auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
        epilogue_v.store(
          epi_load_pipeline, epi_load_pipe_consumer_state,
          epi_store_pipeline, epi_store_pipe_producer_state,
          problem_size_kv, tile_shape, make_coord(get<1>(blk_coord), _0{}, _, get<2>(blk_coord)),
          acc_v, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup,
          epi_tensor_storage[1]
        );

      epilogue_k.store_tail(
        epi_load_pipeline, epi_load_pipe_consumer_state_next,
        epi_store_pipeline, epi_store_pipe_producer_state_next
      );

      epilogue_v.store_tail(
        epi_load_pipeline, epi_load_pipe_consumer_state_next,
        epi_store_pipeline, epi_store_pipe_producer_state_next
      );
    }
  }
};

} // namespace cutlass::fmha::collective
283  examples/88_hopper_fmha/collective/fmha_fusion.hpp  Normal file
@@ -0,0 +1,283 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"

namespace cutlass::fmha::collective {

using namespace cute;

struct DefaultFusion {
  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size
  ) {
    return ceil_div(get<3>(problem_size), get<1>(tile_shape));
  }

  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_masked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size
  ) {
    return get_trip_count(blk_coord, tile_shape, problem_size);
  }

  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_unmasked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size
  ) {
    return 0;
  }

  template<class AccQK, class IndexQK, class ProblemSize>
  CUTLASS_DEVICE
  void before_softmax(
      AccQK& acc_qk,
      IndexQK const& index_qk,
      ProblemSize const& problem_size
  ) {
    return;
  }
};

struct ResidualFusion : DefaultFusion {

  using Base = DefaultFusion;

  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_masked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size
  ) {
    return 1;
  }

  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_unmasked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size
  ) {
    return get_trip_count(blk_coord, tile_shape, problem_size) - 1;
  }

  template<class AccQK, class IndexQK, class ProblemSize>
  CUTLASS_DEVICE
  void before_softmax(
      AccQK& acc_qk,
      IndexQK const& index_qk,
      ProblemSize const& problem_size
  ) {
    // This is useful if seqlen_k % kBlockN != 0, since it masks
    // the remaining elements out from softmax.
    // d % kHeadDim != 0 or seqlen_q % kBlockM != 0 do not suffer from similar
    // issues, as they are transparently taken care of by TMA and the
    // epilogue, if it is instantiated with predication support.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size(acc_qk); i++) {
      auto pos = index_qk(i);
      if (get<1>(pos) >= get<3>(problem_size)) {
        acc_qk(i) = -INFINITY;
      }
    }
  }
};
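The -INFINITY writes above rely on the usual softmax identity: exp(-inf) = 0, so masked columns contribute nothing to the normalizer. A small host-side check of a QK^T row padded past seqlen_k (illustrative only, not part of the example):

#include <cmath>
#include <cstdio>

int main() {
  // A row of length 8 whose last 3 columns fall past seqlen_k and were set to -inf.
  float row[8] = {0.5f, 1.0f, -0.25f, 2.0f, 0.75f, -INFINITY, -INFINITY, -INFINITY};
  float max_v = row[0];
  for (float v : row) max_v = fmaxf(max_v, v);
  float denom = 0.f;
  for (float v : row) denom += expf(v - max_v);   // exp(-inf - max_v) == 0
  float masked_mass = (expf(row[5] - max_v) + expf(row[6] - max_v) + expf(row[7] - max_v)) / denom;
  printf("probability mass of masked columns: %f\n", masked_mass);   // prints 0.000000
  return 0;
}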

struct CausalFusion : DefaultFusion {

  using Base = DefaultFusion;

  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size
  ) {
    // See note below on different ways to think about causal attention
    // Again, we'd add the offset_q into the max_blocks_q calculation
    int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size);
    int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));
    return std::min(max_blocks_k, max_blocks_q);
  }

  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_masked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size
  ) {
    return ceil_div(get<0>(tile_shape), get<1>(tile_shape));
  }

  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_unmasked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size
  ) {
    return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size);
  }

  template<class AccQK, class IndexQK, class ProblemSize>
  CUTLASS_DEVICE
  void before_softmax(
      AccQK& acc_qk,
      IndexQK const& index_qk,
      ProblemSize const& problem_size
  ) {
    // There are two ways to do causal if N_Q != N_K
    // (1) is to assume that the Q is at the beginning of the matrix
    //    - this is what we demonstrate here
    // (2) is that it is at the end of the matrix
    //    - this is usually what we want for inference settings
    //      where we only compute the next row and use cache for the rest
    //    - if you'd like this, you only need to add an offset like so:
    //      get<0>(pos) + offset_q < get<1>(pos)
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size(acc_qk); i++) {
      auto pos = index_qk(i);
      if (get<0>(pos) < get<1>(pos)) {
        acc_qk(i) = -INFINITY;
      }
    }
  }

};
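To make the two conventions in the comment above concrete, the small host-side sketch below prints the causal mask for N_Q != N_K under both interpretations (illustrative only; offset_q is the hypothetical offset mentioned in the comment, not a symbol in the example):

#include <cstdio>

int main() {
  int N_Q = 2, N_K = 6;
  int offset_q = N_K - N_Q;  // aligns Q to the end of the K range (convention (2))
  printf("convention (1): Q at the start   |   convention (2): Q at the end\n");
  for (int q = 0; q < N_Q; q++) {
    for (int k = 0; k < N_K; k++) putchar(q < k ? '.' : 'x');             // masked if q < k
    printf("                           ");
    for (int k = 0; k < N_K; k++) putchar(q + offset_q < k ? '.' : 'x');  // masked if q + offset_q < k
    putchar('\n');
  }
  return 0;
}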

template<class Base>
struct FusionBwdAdapter {
  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size
  ) {
    return Base{}.get_trip_count(select<1,0,2>(blk_coord), select<1,0,2>(tile_shape), select<0,1,3,2,4>(problem_size));
  }

  template<class AccQK, class IndexQK, class ProblemSize>
  CUTLASS_DEVICE
  void before_softmax(
      AccQK& acc_qk,
      IndexQK const& index_qk,
      ProblemSize const& problem_size
  ) {
    auto index_base = index_qk(_0{});
    auto index_shape = shape(index_qk);
    auto index_stride = transform_leaf(stride(index_qk), [](auto elem) {
      if constexpr (is_scaled_basis<decltype(elem)>::value) {
        if constexpr(decltype(elem.mode() == _0{})::value) {
          return ScaledBasis<decltype(elem.value()), 1>(elem.value());
        } else {
          return ScaledBasis<decltype(elem.value()), 0>(elem.value());
        }
      } else {
        return elem;
      }
    });
    auto index_qk_bwd = make_tensor(make_inttuple_iter(select<1,0>(index_base)), make_layout(index_shape, index_stride));
    Base{}.before_softmax(acc_qk, index_qk_bwd, problem_size);
  }

  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  bool is_contributing(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size
  ) {
    return true;
  }
};

template<>
struct FusionBwdAdapter<CausalFusion> {
  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size
  ) {
    return get<2>(problem_size) / get<0>(TileShape{});
  }

  template<class AccQK, class IndexQK, class ProblemSize>
  CUTLASS_DEVICE
  void before_softmax(
      AccQK& acc_qk,
      IndexQK const& index_qk,
      ProblemSize const& problem_size
  ) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size(acc_qk); i++) {
      auto pos = index_qk(i);
      if (get<1>(pos) < get<0>(pos)) {
        acc_qk(i) = -INFINITY;
      }
    }
  }

  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  bool is_contributing(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size
  ) {
    int max_q = get<0>(blk_coord) * get<0>(tile_shape) + get<0>(tile_shape);
    int min_k = get<1>(blk_coord) * get<1>(tile_shape);
    return min_k <= max_q;
  }
};

} // namespace cutlass::fmha::collective
278  examples/88_hopper_fmha/device/device_universal.hpp  Normal file
@@ -0,0 +1,278 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*!
  \file
  \brief A universal device layer for CUTLASS 3.x-style kernels.
*/

#pragma once

// common
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"

#if !defined(__CUDACC_RTC__)
#include "cutlass/cluster_launch.hpp"
#include "cutlass/trace.h"
#endif // !defined(__CUDACC_RTC__)

////////////////////////////////////////////////////////////////////////////////

namespace cutlass::device {

////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

template <class Kernel_>
class Universal {
public:
  using Kernel = Kernel_;

  static int const kThreadCount = Kernel::MaxThreadsPerBlock;

  /// Argument structure: User API
  using Arguments = typename Kernel::Arguments;
  /// Argument structure: Kernel API
  using Params = typename Kernel::Params;

private:

  /// Kernel API parameters object
  Params params_;

  bool is_initialized(bool set = false) {
    static bool initialized = false;
    if (set) initialized = true;
    return initialized;
  }

public:

  /// Access the Params structure
  Params const& params() const {
    return params_;
  }

  /// Determines whether the GEMM can execute the given problem.
  static Status
  can_implement(Arguments const& args) {
    if (Kernel::can_implement(args)) {
      return Status::kSuccess;
    }
    else {
      return Status::kInvalid;
    }
  }

  /// Gets the workspace size
  static size_t
  get_workspace_size(Arguments const& args) {
    size_t workspace_bytes = 0;
    workspace_bytes += Kernel::get_workspace_size(args);
    return workspace_bytes;
  }

  /// Computes the grid shape
  static dim3
  get_grid_shape(Params const& params) {
    return Kernel::get_grid_shape(params);
  }

  /// Computes the maximum number of active blocks per multiprocessor
  static int maximum_active_blocks(int /* smem_capacity */ = -1) {
    CUTLASS_TRACE_HOST("Universal::maximum_active_blocks()");
    int max_active_blocks = -1;
    int smem_size = Kernel::SharedStorageSize;

    // first, account for dynamic smem capacity if needed
    cudaError_t result;
    if (smem_size >= (48 << 10)) {
      CUTLASS_TRACE_HOST("  Setting smem size to " << smem_size);
      result = cudaFuncSetAttribute(
          device_kernel<Kernel>,
          cudaFuncAttributeMaxDynamicSharedMemorySize,
          smem_size);
      if (cudaSuccess != result) {
        result = cudaGetLastError(); // to clear the error bit
        CUTLASS_TRACE_HOST(
            "  cudaFuncSetAttribute() returned error: "
            << cudaGetErrorString(result));
        return -1;
      }
    }

    // query occupancy after setting smem size
    result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_active_blocks,
        device_kernel<Kernel>,
        Kernel::MaxThreadsPerBlock,
        smem_size);

    if (cudaSuccess != result) {
      result = cudaGetLastError(); // to clear the error bit
      CUTLASS_TRACE_HOST(
          "  cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
          << cudaGetErrorString(result));
      return -1;
    }

    CUTLASS_TRACE_HOST("  max_active_blocks: " << max_active_blocks);
    return max_active_blocks;
  }

  /// Initializes GEMM state from arguments.
  Status
  initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("Universal::initialize() - workspace "
        << workspace << ", stream: " << (stream ? "non-null" : "null"));

    // Initialize the workspace
    Status status = Kernel::initialize_workspace(args, workspace, stream);
    if (status != Status::kSuccess) {
      return status;
    }

    // Initialize the Params structure
    params_ = Kernel::to_underlying_arguments(args, workspace);

    if (is_initialized()) return Status::kSuccess;

    // account for dynamic smem capacity if needed
    int smem_size = Kernel::SharedStorageSize;
    if (smem_size >= (48 << 10)) {
      CUTLASS_TRACE_HOST("  Setting smem size to " << smem_size);
      cudaError_t result = cudaFuncSetAttribute(
          device_kernel<Kernel>,
          cudaFuncAttributeMaxDynamicSharedMemorySize,
          smem_size);
      if (cudaSuccess != result) {
        result = cudaGetLastError(); // to clear the error bit
        CUTLASS_TRACE_HOST("  cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
        return Status::kErrorInternal;
      }
    }

    is_initialized(true);

    return Status::kSuccess;
  }

  /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
  Status
  update(Arguments const& args, void* workspace = nullptr) {
    CUTLASS_TRACE_HOST("Universal::update() - workspace: " << workspace);

    size_t workspace_bytes = get_workspace_size(args);
    if (workspace_bytes > 0 && nullptr == workspace) {
      return Status::kErrorWorkspaceNull;
    }

    params_ = Kernel::to_underlying_arguments(args, workspace);
    return Status::kSuccess;
  }

  /// Primary run() entry point API that is static allowing users to create and manage their own params.
  /// Supplied params struct must be constructed by calling Kernel::to_underlying_arguments()
  static Status
  run(Params& params, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("Universal::run()");
    dim3 const block = Kernel::get_block_shape();
    dim3 const grid = get_grid_shape(params);

    // configure smem size and carveout
    int smem_size = Kernel::SharedStorageSize;

    Status launch_result;
    // Use extended launch API only for mainloops that use it
    if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
      dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
                   cute::size<1>(typename Kernel::ClusterShape{}),
                   cute::size<2>(typename Kernel::ClusterShape{}));
      void const* kernel = (void const*) device_kernel<Kernel>;
      void* kernel_params[] = {&params};
      launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
    }
    else {
      launch_result = Status::kSuccess;
      cutlass::arch::synclog_setup();
      device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params);
    }

    cudaError_t result = cudaGetLastError();
    if (cudaSuccess == result && Status::kSuccess == launch_result) {
      return Status::kSuccess;
    }
    else {
      CUTLASS_TRACE_HOST("  Kernel launch failed. Reason: " << result);
      return Status::kErrorInternal;
    }
  }

  //
  // Non-static launch overloads that first create and set the internal params struct of this kernel handle.
  //

  /// Launches the kernel after first constructing Params internal state from supplied arguments.
  Status
  run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    Status status = initialize(args, workspace, stream);
    if (Status::kSuccess == status) {
      status = run(params_, stream);
    }
    return status;
  }

  /// Launches the kernel after first constructing Params internal state from supplied arguments.
  Status
  operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    return run(args, workspace, stream);
  }

  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
  Status
  run(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  }

  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
  Status
  operator()(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  }
};

////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::device

////////////////////////////////////////////////////////////////////////////////
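Universal::maximum_active_blocks() above follows the standard CUDA occupancy recipe: opt the kernel into more than 48 KiB of dynamic shared memory when needed, then query cudaOccupancyMaxActiveBlocksPerMultiprocessor. A minimal standalone version of that pattern (illustrative only; the kernel, block size, and shared memory amount are made-up values, not part of the example):

#include <cstdio>
#include <cuda_runtime.h>

__global__ void dummy_kernel() {
  extern __shared__ char smem[];
  smem[threadIdx.x] = 0;
}

int main() {
  int smem_size = 100 << 10;  // 100 KiB of dynamic shared memory (made-up requirement)
  // Kernels that want more than 48 KiB of dynamic shared memory must opt in first.
  cudaError_t err = cudaFuncSetAttribute(
      dummy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
  if (err != cudaSuccess) {
    printf("cudaFuncSetAttribute: %s\n", cudaGetErrorString(err));
    return 1;
  }
  int max_active_blocks = 0;
  err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &max_active_blocks, dummy_kernel, /*blockSize=*/128, smem_size);
  if (err != cudaSuccess) {
    printf("cudaOccupancyMaxActiveBlocksPerMultiprocessor: %s\n", cudaGetErrorString(err));
    return 1;
  }
  printf("max active blocks per SM: %d\n", max_active_blocks);
  return 0;
}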
299  examples/88_hopper_fmha/device/fmha_device_bwd.hpp  Normal file
@@ -0,0 +1,299 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

/*!
  \file
  \brief A universal device layer for CUTLASS 3.x-style kernels.
*/

// common
#include "cutlass/cutlass.h"

#include "../device/device_universal.hpp"
#include "../collective/fmha_collective_bwd_tma_warpspecialized.hpp"
#include "../collective/fmha_fusion.hpp"
#include "../collective/fmha_epilogue_bwd.hpp"
#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp"
#include "../kernel/fmha_kernel_bwd_convert.hpp"
#include "../kernel/fmha_kernel_tma_warpspecialized.hpp"
#include "../kernel/fmha_tile_scheduler.hpp"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass::fmha::device {

////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

template<class Element, class ElementAccumulator, class TileShape, class Fusion, class... Options>
class FmhaBwd {
public:
  /// Argument structure: User API
  struct Arguments {
    cute::tuple<int, int, int, int, int> problem_size;

    const Element* ptr_Q;
    cute::tuple<int, int, int, cute::_1> stride_Q;
    const Element* ptr_K;
    cute::tuple<int, int, int, cute::_1> stride_K;
    const Element* ptr_V;
    cute::tuple<int, int, int, cute::_1> stride_V;

    const Element* ptr_O;
    cute::tuple<int, int, int, cute::_1> stride_O;
    const ElementAccumulator* ptr_LSE;
    cute::tuple<int, int, _1> stride_LSE;

    const Element* ptr_dO;
    cute::tuple<int, int, int, cute::_1> stride_dO;

    Element* ptr_dQ;
    cute::tuple<int, int, int, cute::_1> stride_dQ;
    Element* ptr_dK;
    cute::tuple<int, int, int, cute::_1> stride_dK;
    Element* ptr_dV;
    cute::tuple<int, int, int, cute::_1> stride_dV;

    cutlass::KernelHardwareInfo hw_info;
  };

  using OperationSumOdO = cutlass::device::Universal<cutlass::fmha::kernel::FmhaKernelBwdSumOdO<Element, ElementAccumulator>>;
  using OperationConvert = cutlass::device::Universal<cutlass::fmha::kernel::FmhaKernelBwdConvert<Element, ElementAccumulator>>;

  using Mainloop = cutlass::fmha::collective::FmhaBwdMainloopTmaWarpSpecialized<
      Element, ElementAccumulator, TileShape,
      cutlass::fmha::collective::FusionBwdAdapter<Fusion>, Options...>;

  using Epilogue = cutlass::fmha::collective::FmhaBwdEpilogueKV<Element, ElementAccumulator, typename Mainloop::TileShapePV>;

  using Operation = cutlass::device::Universal<
      cutlass::fmha::kernel::FmhaKernelTmaWarpSpecialized<
        Mainloop,
        Epilogue,
        cutlass::fmha::kernel::TileSchedulerBwdAdapter<cutlass::fmha::kernel::IndividualTileScheduler>, Options...>>;

  struct Params {
    OperationSumOdO op_sum_OdO;
    Operation op;
    OperationConvert op_convert;
    ElementAccumulator* dQ_acc;
    size_t dQ_acc_size;
  };

private:
  Params params_;

  static typename OperationSumOdO::Arguments to_sum_OdO_arguments(Arguments const& args, ElementAccumulator* dest = nullptr) {
    auto [B, H, Q, K, D] = args.problem_size;
    D = cutlass::round_up(D, 8); // Alignment
    Q = cutlass::round_up(Q, 8); // Alignment
    auto stride_sum_OdO = make_stride(H*Q, Q, _1{});
    return typename OperationSumOdO::Arguments {
      args.problem_size,
      args.ptr_O, args.stride_O,
      args.ptr_dO, args.stride_dO,
      dest, stride_sum_OdO
    };
  }

  static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) {
    auto [B, H, Q, K, D] = args.problem_size;
    D = cutlass::round_up(D, 8); // Alignment
    Q = cutlass::round_up(Q, 8); // Alignment
    auto stride_src_dQ = make_stride(B == 1 ? 0 : (H*Q*D), Q*D, D, _1{});
    return typename OperationConvert::Arguments {
      args.problem_size,
      src, stride_src_dQ,
      nullptr, stride_src_dQ,
      nullptr, stride_src_dQ,
      args.ptr_dQ, args.stride_dQ,
      nullptr, args.stride_dK,
      nullptr, args.stride_dV
    };
  }

  static typename Operation::Arguments to_bwd_arguments(
      Arguments const& args,
      ElementAccumulator* sum_OdO = nullptr, cute::tuple<int, int, _1> const& stride_sum_OdO = {},
      ElementAccumulator* dQ_acc = nullptr, cute::tuple<int, int, int, _1> const& stride_dQ = {}
  ) {
    return typename Operation::Arguments{
      args.problem_size,
      { args.ptr_Q, args.stride_Q,
        args.ptr_K, args.stride_K,
        args.ptr_V, args.stride_V,
        args.ptr_dO, args.stride_dO,
        args.ptr_LSE, args.stride_LSE,
        sum_OdO, stride_sum_OdO,
        dQ_acc, stride_dQ },
      { args.ptr_dK, args.stride_dK,
        args.ptr_dV, args.stride_dV },
      args.hw_info
    };
  }

public:

  /// Determines whether the GEMM can execute the given problem.
  static Status
  can_implement(Arguments const& args) {
    Status status = Status::kSuccess;

    status = OperationSumOdO::can_implement(to_sum_OdO_arguments(args));
    if (status != Status::kSuccess) {
      return status;
    }

    status = OperationConvert::can_implement(to_convert_arguments(args));
    if (status != Status::kSuccess) {
      return status;
    }

    status = Operation::can_implement(to_bwd_arguments(args));
    if (status != Status::kSuccess) {
      return status;
    }

    return status;
  }

  /// Gets the workspace size
  static size_t
  get_workspace_size(Arguments const& args) {
    auto [B, H, Q, K, D] = args.problem_size;
    D = cutlass::round_up(D, 8); // Alignment
    Q = cutlass::round_up(Q, 8); // Alignment
    size_t workspace_bytes = 0;
    // OdO vector
    workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
    // FP32 versions of outputs that are churned (start off with Q only)
    workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator);
    return workspace_bytes;
  }

  /// Initializes state from arguments.
  Status
  initialize_split(Arguments const& args, void* workspace_dQ, void* workspace_sum_OdO, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("FmhaBwd::initialize_split() - workspace_dQ="
        << workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << ", stream: " << (stream ? "non-null" : "null"));

    auto [B, H, Q, K, D] = args.problem_size;
    D = cutlass::round_up(D, 8); // Alignment
    Q = cutlass::round_up(Q, 8); // Alignment
    ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_sum_OdO);
    ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_dQ);
    params_.dQ_acc = dQ_acc;
    params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator);
    auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO);
    auto args_convert = to_convert_arguments(args, dQ_acc);
    params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream);
    params_.op_convert.initialize(args_convert, nullptr, stream);
    auto args_bwd = to_bwd_arguments(args, sum_OdO, args_sum_OdO.stride_sum_OdO, dQ_acc, args_convert.stride_src_dQ);
    params_.op.initialize(args_bwd, nullptr, stream);

    return Status::kSuccess;
  }

  /// Initializes state from arguments.
  Status
  initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("FmhaBwd::initialize() - workspace "
        << workspace << ", stream: " << (stream ? "non-null" : "null"));

    auto [B, H, Q, K, D] = args.problem_size;
    D = cutlass::round_up(D, 8); // Alignment
    Q = cutlass::round_up(Q, 8); // Alignment
    char* workspace_chr = reinterpret_cast<char*>(workspace);
    ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_chr);
    workspace_chr += B*H*Q * sizeof(ElementAccumulator);
    ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_chr);
    return initialize_split(args, dQ_acc, sum_OdO, stream);
  }

  /// Primary run() entry point API that is static allowing users to create and manage their own params.
  /// Supplied params struct must be constructed by calling Kernel::to_underlying_arguments()
  static Status
  run(Params& params, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("FmhaDeviceBwd::run()");

    Status result = Status::kSuccess;
    result = params.op_sum_OdO.run(stream);
    if (result != Status::kSuccess) {
      return result;
    }

    auto cuda_result = cudaMemsetAsync(params.dQ_acc, 0, params.dQ_acc_size, stream);
    if (cuda_result != cudaSuccess) {
      return Status::kErrorInternal;
    }
    result = params.op.run(stream);
    if (result != Status::kSuccess) {
      return result;
    }

    result = params.op_convert.run(stream);
    if (result != Status::kSuccess) {
      return result;
    }

    return Status::kSuccess;
  }

  //
  // Non-static launch overloads that first create and set the internal params struct of this kernel handle.
  //

  /// Launches the kernel after first constructing Params internal state from supplied arguments.
  Status
  run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    Status status = initialize(args, workspace, stream);
    if (Status::kSuccess == status) {
      status = run(params_, stream);
    }
    return status;
  }

  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
  Status
  run(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  }

};

////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::fmha::device

////////////////////////////////////////////////////////////////////////////////
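For a sense of scale, the workspace budget computed by FmhaBwd::get_workspace_size() above is one accumulator element per (batch, head, query) row for the sum(O * dO) vector, plus a full FP32 dQ scratch tensor that is zeroed with cudaMemsetAsync before the backward kernel and converted to the output type afterwards. A quick host-side check with made-up problem dimensions (illustrative only; the 8-element round-up in the example is omitted because these values are already multiples of 8):

#include <cstdio>

int main() {
  long long B = 4, H = 16, Q = 4096, D = 128;     // assumed (batch, heads, seqlen_q, head dim)
  long long elem = sizeof(float);                 // ElementAccumulator = float
  long long sum_OdO_bytes = B * H * Q * elem;     // per-row O.dO dot products
  long long dQ_acc_bytes  = B * H * Q * D * elem; // FP32 dQ scratch
  printf("workspace: %lld + %lld = %lld bytes\n",
         sum_OdO_bytes, dQ_acc_bytes, sum_OdO_bytes + dQ_acc_bytes);
  return 0;
}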
158  examples/88_hopper_fmha/kernel/fmha_kernel_builder.hpp  Normal file
@@ -0,0 +1,158 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "../collective/fmha_collective_tma.hpp"
#include "../collective/fmha_collective_tma_warpspecialized.hpp"
#include "../collective/fmha_epilogue.hpp"
#include "../kernel/fmha_kernel_tma.hpp"
#include "../kernel/fmha_kernel_tma_warpspecialized.hpp"
#include "../kernel/fmha_options.hpp"

namespace cutlass::fmha::kernel {

template<
  class Element_,
  class ElementAccumulatorQK_,
  class ElementAccumulatorPV_,
  class TileShape_, // BlockQO, BlockKV, BlockHead
  class LayoutQ_,
  class LayoutK_,
  class LayoutV_,
  class Fusion,
  class DispatchPolicy,
  class... Options
>
struct FmhaBuilder;

template<
  class Element,
  class ElementAccumulator,
  class TileShape, // BlockQO, BlockKV, BlockHead
  class Fusion,
  class... Options
>
struct FmhaBuilder<
  Element,
  ElementAccumulator,
  ElementAccumulator,
  TileShape,
  cute::tuple<int, _1, cute::tuple<int, int>>,
  cute::tuple<int, _1, cute::tuple<int, int>>,
  cute::tuple<int, _1, cute::tuple<int, int>>,
  Fusion,
  cutlass::gemm::KernelTma,
  Options...
> {

  using CollectiveMainloop = cutlass::fmha::collective::FmhaMainloopTma<Element, ElementAccumulator, TileShape, Fusion, Options...>;

  using CollectiveEpilogue = cutlass::fmha::collective::FmhaFwdEpilogue<
      Element, ElementAccumulator, typename CollectiveMainloop::TileShapePV>;

  using Kernel = cutlass::fmha::kernel::FmhaKernelTma<CollectiveMainloop, CollectiveEpilogue, Options...>;
};

template<
  class Element,
  class ElementAccumulatorQK,
  class ElementAccumulatorPV,
  class TileShape, // BlockQO, BlockKV, BlockHead
  class LayoutQ,
  class LayoutK,
  class LayoutV,
  class Fusion,
  class... Options
>
struct FmhaBuilder<
  Element,
  ElementAccumulatorQK,
  ElementAccumulatorPV,
  TileShape,
  LayoutQ,
  LayoutK,
  LayoutV,
  Fusion,
  cutlass::gemm::KernelTmaWarpSpecializedCooperative,
  Options...
> {

  using CollectiveMainloop = cutlass::fmha::collective::FmhaMainloopTmaWarpSpecialized<
      Element, ElementAccumulatorQK, ElementAccumulatorPV,
      TileShape, LayoutQ, LayoutK, LayoutV,
      Fusion, Options...>;

  using CollectiveEpilogue = cutlass::fmha::collective::FmhaFwdEpilogue<
      Element, ElementAccumulatorPV, typename CollectiveMainloop::TileShapePV>;

  static constexpr bool kIsPersistent = find_option_t<Tag::kIsPersistent, false_type, Options...>::value;
  using TileScheduler = std::conditional_t<kIsPersistent, cutlass::fmha::kernel::PersistentTileScheduler, cutlass::fmha::kernel::IndividualTileScheduler>;

  using Kernel = cutlass::fmha::kernel::FmhaKernelTmaWarpSpecialized<CollectiveMainloop, CollectiveEpilogue, TileScheduler, Options...>;
};

template<
  class Element,
  class ElementAccumulatorQK,
  class ElementAccumulatorPV,
  class TileShape, // BlockQO, BlockKV, BlockHead
  class LayoutQ,
  class LayoutK,
  class LayoutV,
  class Fusion,
  class... Options
>
|
||||||
|
struct FmhaBuilder<
|
||||||
|
Element,
|
||||||
|
ElementAccumulatorQK,
|
||||||
|
ElementAccumulatorPV,
|
||||||
|
TileShape,
|
||||||
|
LayoutQ,
|
||||||
|
LayoutK,
|
||||||
|
LayoutV,
|
||||||
|
Fusion,
|
||||||
|
cutlass::gemm::KernelTmaWarpSpecializedPingpong,
|
||||||
|
Options...
|
||||||
|
> {
|
||||||
|
using Kernel = typename FmhaBuilder<
|
||||||
|
Element, ElementAccumulatorQK, ElementAccumulatorPV,
|
||||||
|
TileShape,
|
||||||
|
LayoutQ, LayoutK, LayoutV,
|
||||||
|
Fusion,
|
||||||
|
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
|
||||||
|
Options...,
|
||||||
|
Option<Tag::kIsPersistent, true_type>,
|
||||||
|
Option<Tag::kLoadsQSeparately, true_type>
|
||||||
|
>::Kernel;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace cutlass::fmha::kernel
|
||||||
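A hedged sketch of how the builder above is meant to be instantiated follows. The element type, tile shape, and fusion parameter are illustrative placeholders; only the stride-tuple layout form and the `KernelTma` dispatch tag mirror the first specialization in this file.

```cpp
// Sketch only: `CausalFusion` is a stand-in for whatever fusion type the example
// defines elsewhere; it is not declared in this diff.
using StrideBHS = cute::tuple<int, cute::_1, cute::tuple<int, int>>;  // layout tuple as in the specialization above
using Builder = cutlass::fmha::kernel::FmhaBuilder<
    cutlass::half_t,                                  // Element
    float, float,                                     // ElementAccumulatorQK == ElementAccumulatorPV
    cute::Shape<cute::_128, cute::_128, cute::_64>,   // TileShape: BlockQO, BlockKV, BlockHead
    StrideBHS, StrideBHS, StrideBHS,                  // LayoutQ, LayoutK, LayoutV
    CausalFusion,                                     // hypothetical fusion policy
    cutlass::gemm::KernelTma>;
using FmhaKernel = typename Builder::Kernel;          // kernel type consumed by the device launcher
```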
143  examples/88_hopper_fmha/kernel/fmha_kernel_bwd_convert.hpp (new file)
@@ -0,0 +1,143 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cute/layout.hpp"
|
||||||
|
|
||||||
|
namespace cutlass::fmha::kernel {
|
||||||
|
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
template<class Element, class ElementAccumulator>
|
||||||
|
struct FmhaKernelBwdConvert {
|
||||||
|
|
||||||
|
struct Arguments {
|
||||||
|
tuple<int, int, int, int, int> problem_size;
|
||||||
|
|
||||||
|
const ElementAccumulator* ptr_src_dQ;
|
||||||
|
tuple<int, int, int, _1> stride_src_dQ;
|
||||||
|
const ElementAccumulator* ptr_src_dK;
|
||||||
|
tuple<int, int, int, _1> stride_src_dK;
|
||||||
|
const ElementAccumulator* ptr_src_dV;
|
||||||
|
tuple<int, int, int, _1> stride_src_dV;
|
||||||
|
|
||||||
|
Element* ptr_dest_dQ;
|
||||||
|
tuple<int, int, int, _1> stride_dest_dQ;
|
||||||
|
Element* ptr_dest_dK;
|
||||||
|
tuple<int, int, int, _1> stride_dest_dK;
|
||||||
|
Element* ptr_dest_dV;
|
||||||
|
tuple<int, int, int, _1> stride_dest_dV;
|
||||||
|
};
|
||||||
|
|
||||||
|
using Params = Arguments;
|
||||||
|
|
||||||
|
using ClusterShape = Shape<_1, _1, _1>;
|
||||||
|
static constexpr int SharedStorageSize = 0;
|
||||||
|
|
||||||
|
static const int MinBlocksPerMultiprocessor = 1;
|
||||||
|
static const int MaxThreadsPerBlock = 128;
|
||||||
|
using ArchTag = cutlass::arch::Sm90;
|
||||||
|
|
||||||
|
static const int kBlockSeq = 8;
|
||||||
|
|
||||||
|
static size_t get_workspace_size(Arguments const& args) { return 0; }
|
||||||
|
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
|
||||||
|
return cutlass::Status::kSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
static const int kNumThreadsD = 16;
|
||||||
|
static const int kNumThreadsSeq = MaxThreadsPerBlock / kNumThreadsD;
|
||||||
|
static const int kElementsPerLoad = 4;
|
||||||
|
|
||||||
|
static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq;
|
||||||
|
|
||||||
|
static bool can_implement(Arguments const& args) {
|
||||||
|
return get<4>(args.problem_size) % kElementsPerLoad == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static dim3 get_grid_shape(Params const& params) {
|
||||||
|
dim3 grid(size<0>(params.problem_size), size<1>(params.problem_size), ceil_div(std::max(size<2>(params.problem_size), size<3>(params.problem_size)), kBlockSeq));
|
||||||
|
return grid;
|
||||||
|
}
|
||||||
|
|
||||||
|
static dim3 get_block_shape() {
|
||||||
|
dim3 block(kNumThreadsD, kNumThreadsSeq, 1);
|
||||||
|
return block;
|
||||||
|
}
|
||||||
|
|
||||||
|
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||||
|
return args;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class StrideSrc, class StrideDest>
|
||||||
|
CUTLASS_DEVICE void copy(Params const& params, const ElementAccumulator* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, int count) {
|
||||||
|
auto ptr_src_bh = ptr_src + get<0>(stride_src) * blockIdx.x + get<1>(stride_src) * blockIdx.y;
|
||||||
|
auto ptr_dest_bh = ptr_dest + get<0>(stride_dest) * blockIdx.x + get<1>(stride_dest) * blockIdx.y;
|
||||||
|
|
||||||
|
for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) {
|
||||||
|
int idx_s = idx_s_t + kBlockSeq * blockIdx.z;
|
||||||
|
if (idx_s >= count) continue;
|
||||||
|
auto ptr_src_bhs = ptr_src_bh + idx_s * get<2>(stride_src);
|
||||||
|
auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<2>(stride_dest);
|
||||||
|
|
||||||
|
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<4>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
|
||||||
|
ElementAccumulator value_src[kElementsPerLoad];
|
||||||
|
Element value_dest[kElementsPerLoad];
|
||||||
|
|
||||||
|
using VecSrc = uint_bit_t<sizeof_bits_v<ElementAccumulator> * kElementsPerLoad>;
|
||||||
|
using VecDest = uint_bit_t<sizeof_bits_v<Element> * kElementsPerLoad>;
|
||||||
|
*reinterpret_cast<VecSrc*>(value_src) = *reinterpret_cast<const VecSrc*>(&ptr_src_bhs[idx_d]);
|
||||||
|
|
||||||
|
for (int v = 0; v < kElementsPerLoad; v++) {
|
||||||
|
value_dest[v] = value_src[v];
|
||||||
|
}
|
||||||
|
|
||||||
|
*reinterpret_cast<VecDest*>(&ptr_dest_bhs[idx_d]) = *reinterpret_cast<const VecDest*>(value_dest);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
||||||
|
if (params.ptr_src_dQ != nullptr) {
|
||||||
|
copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<2>(params.problem_size));
|
||||||
|
}
|
||||||
|
if (params.ptr_src_dK != nullptr) {
|
||||||
|
copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<3>(params.problem_size));
|
||||||
|
}
|
||||||
|
if (params.ptr_src_dV != nullptr) {
|
||||||
|
copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<3>(params.problem_size));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace cutlass::fmha::kernel
|
||||||
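One detail worth making concrete: a single launch of this convert kernel covers both dQ (seqlen_Q rows) and dK/dV (seqlen_KV rows), because the grid's z dimension spans the larger of the two and each `copy` call is handed its own row `count`. A small sketch with hypothetical sizes; none of these numbers appear in the diff.

```cpp
#include <algorithm>

// (B, H, seqlen_Q, seqlen_KV, head_dim) = (4, 8, 1024, 2048, 128), kBlockSeq = 8.
constexpr int B = 4, H = 8, seqlen_Q = 1024, seqlen_KV = 2048, kBlockSeq = 8;
constexpr int grid_z = (std::max(seqlen_Q, seqlen_KV) + kBlockSeq - 1) / kBlockSeq;  // ceil_div = 256
static_assert(grid_z == 256);
// get_grid_shape() then yields dim3(B, H, grid_z) = dim3(4, 8, 256); blocks whose
// row index exceeds seqlen_Q simply skip the dQ copy via the count check.
```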
134  examples/88_hopper_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp (new file)
@@ -0,0 +1,134 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cute/layout.hpp"
|
||||||
|
|
||||||
|
namespace cutlass::fmha::kernel {
|
||||||
|
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
template<class Element, class ElementAccumulator>
|
||||||
|
struct FmhaKernelBwdSumOdO {
|
||||||
|
|
||||||
|
struct Arguments {
|
||||||
|
cute::tuple<int, int, int, int, int> problem_size;
|
||||||
|
|
||||||
|
const Element* ptr_O;
|
||||||
|
cute::tuple<int, int, int, cute::_1> stride_O;
|
||||||
|
const Element* ptr_dO;
|
||||||
|
cute::tuple<int, int, int, cute::_1> stride_dO;
|
||||||
|
|
||||||
|
ElementAccumulator* ptr_sum_OdO;
|
||||||
|
cute::tuple<int, int, _1> stride_sum_OdO;
|
||||||
|
};
|
||||||
|
|
||||||
|
using Params = Arguments;
|
||||||
|
|
||||||
|
using ClusterShape = Shape<_1, _1, _1>;
|
||||||
|
static constexpr int SharedStorageSize = 0;
|
||||||
|
|
||||||
|
static const int MinBlocksPerMultiprocessor = 1;
|
||||||
|
static const int MaxThreadsPerBlock = 128;
|
||||||
|
using ArchTag = cutlass::arch::Sm90;
|
||||||
|
|
||||||
|
static size_t get_workspace_size(Arguments const& args) { return 0; }
|
||||||
|
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
|
||||||
|
return cutlass::Status::kSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
static const int kBlockQ = 16;
|
||||||
|
|
||||||
|
static const int kNumThreadsD = 8;
|
||||||
|
static const int kNumThreadsQ = MaxThreadsPerBlock / kNumThreadsD;
|
||||||
|
static const int kElementsPerLoad = 2;
|
||||||
|
|
||||||
|
static const int kIterationsQ = kBlockQ / kNumThreadsQ;
|
||||||
|
|
||||||
|
static bool can_implement(Arguments const& args) {
|
||||||
|
return get<4>(args.problem_size) % kElementsPerLoad == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static dim3 get_grid_shape(Params const& params) {
|
||||||
|
dim3 grid(ceil_div(size<2>(params.problem_size), kBlockQ), size<1>(params.problem_size), size<0>(params.problem_size));
|
||||||
|
return grid;
|
||||||
|
}
|
||||||
|
|
||||||
|
static dim3 get_block_shape() {
|
||||||
|
dim3 block(kNumThreadsD, kNumThreadsQ, 1);
|
||||||
|
return block;
|
||||||
|
}
|
||||||
|
|
||||||
|
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||||
|
return args;
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
||||||
|
auto ptr_O_bh = params.ptr_O + blockIdx.y * get<1>(params.stride_O) + blockIdx.z * get<0>(params.stride_O);
|
||||||
|
auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<1>(params.stride_dO) + blockIdx.z * get<0>(params.stride_dO);
|
||||||
|
auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1>(params.stride_sum_OdO) + blockIdx.z * get<0>(params.stride_sum_OdO);
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) {
|
||||||
|
int idx_q = idx_q_t + kBlockQ * blockIdx.x;
|
||||||
|
if (idx_q >= get<2>(params.problem_size)) continue;
|
||||||
|
ElementAccumulator acc = 0;
|
||||||
|
auto ptr_O_bhq = ptr_O_bh + idx_q * get<2>(params.stride_O);
|
||||||
|
auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<2>(params.stride_dO);
|
||||||
|
auto ptr_sum_OdO_bhq = ptr_sum_OdO_bh + idx_q * get<2>(params.stride_sum_OdO);
|
||||||
|
|
||||||
|
for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<4>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) {
|
||||||
|
Element value_O[kElementsPerLoad];
|
||||||
|
Element value_dO[kElementsPerLoad];
|
||||||
|
|
||||||
|
using Vec = uint_bit_t<sizeof_bits_v<Element> * kElementsPerLoad>;
|
||||||
|
*reinterpret_cast<Vec*>(value_O) = *reinterpret_cast<const Vec*>(&ptr_O_bhq[idx_d]);
|
||||||
|
*reinterpret_cast<Vec*>(value_dO) = *reinterpret_cast<const Vec*>(&ptr_dO_bhq[idx_d]);
|
||||||
|
|
||||||
|
for (int v = 0; v < kElementsPerLoad; v++) {
|
||||||
|
acc += value_O[v] * value_dO[v];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 1; i < kNumThreadsD; i *= 2) {
|
||||||
|
acc += __shfl_xor_sync((uint32_t)-1, acc, i, kNumThreadsD);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
*ptr_sum_OdO_bhq = acc;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace cutlass::fmha::kernel
|
||||||
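Reading the accumulation loop above, each query row $i$ gets the dot product of the output and its gradient over the head dimension; in FlashAttention-style backward passes this row sum is the quantity usually written as $D_i$:

$$D_i = \sum_{d} O_{i,d}\,\mathrm{d}O_{i,d}$$

Each 8-thread group along the head dimension then combines its partial sums with the `__shfl_xor_sync` butterfly before lane 0 writes the result.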
222  examples/88_hopper_fmha/kernel/fmha_kernel_tma.hpp (new file)
@@ -0,0 +1,222 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/pipeline/pipeline.hpp"
|
||||||
|
#include "cutlass/arch/arch.h"
|
||||||
|
|
||||||
|
#include "../kernel/fmha_tile_scheduler.hpp"
|
||||||
|
#include "../kernel/fmha_options.hpp"
|
||||||
|
|
||||||
|
namespace cutlass::fmha::kernel {
|
||||||
|
|
||||||
|
template<
|
||||||
|
class CollectiveMainloop,
|
||||||
|
class CollectiveEpilogue,
|
||||||
|
class... Options
|
||||||
|
>
|
||||||
|
struct FmhaKernelTma {
|
||||||
|
|
||||||
|
// Options
|
||||||
|
static constexpr int kBlocksPerSM = find_option_t<Tag::kBlocksPerSM, Int<2>, Options...>::value;
|
||||||
|
|
||||||
|
using Element = typename CollectiveMainloop::Element;
|
||||||
|
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
|
||||||
|
|
||||||
|
using TileScheduler = IndividualTileScheduler;
|
||||||
|
|
||||||
|
using StagesQ = typename CollectiveMainloop::StagesQ;
|
||||||
|
using Stages = typename CollectiveMainloop::Stages;
|
||||||
|
|
||||||
|
using TileShape = typename CollectiveMainloop::TileShape;
|
||||||
|
using ClusterShape = typename CollectiveMainloop::ClusterShape;
|
||||||
|
|
||||||
|
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
|
||||||
|
using MainloopPipelineQ = typename CollectiveMainloop::MainloopPipelineQ;
|
||||||
|
|
||||||
|
using SmemLayoutQ = typename CollectiveMainloop::SmemLayoutQ;
|
||||||
|
using SmemLayoutK = typename CollectiveMainloop::SmemLayoutK;
|
||||||
|
|
||||||
|
struct SharedStorage {
|
||||||
|
union {
|
||||||
|
typename CollectiveMainloop::SharedStorage mainloop;
|
||||||
|
typename CollectiveEpilogue::TensorStorage epilogue;
|
||||||
|
};
|
||||||
|
|
||||||
|
using PipelineStorage = typename MainloopPipeline::SharedStorage;
|
||||||
|
using PipelineStorageQ = typename MainloopPipelineQ::SharedStorage;
|
||||||
|
alignas(16) PipelineStorage pipeline_storage;
|
||||||
|
alignas(16) PipelineStorageQ pipeline_storage_q;
|
||||||
|
|
||||||
|
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
|
||||||
|
alignas(16) EpiLoadPipelineStorage epi_load;
|
||||||
|
};
|
||||||
|
|
||||||
|
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
||||||
|
|
||||||
|
using ProblemShape = cute::tuple<int, int, int, int, int>;
|
||||||
|
|
||||||
|
struct Arguments {
|
||||||
|
ProblemShape problem_size;
|
||||||
|
typename CollectiveMainloop::Arguments mainloop;
|
||||||
|
typename CollectiveEpilogue::Arguments epilogue;
|
||||||
|
KernelHardwareInfo hw_info;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Params {
|
||||||
|
ProblemShape problem_size;
|
||||||
|
typename CollectiveMainloop::Params mainloop;
|
||||||
|
typename CollectiveEpilogue::Params epilogue;
|
||||||
|
typename TileScheduler::Params tile_scheduler;
|
||||||
|
};
|
||||||
|
|
||||||
|
using PipelineParams = typename MainloopPipeline::Params;
|
||||||
|
using PipelineState = typename cutlass::PipelineState<MainloopPipeline::Stages>;
|
||||||
|
using PipelineParamsQ = typename MainloopPipelineQ::Params;
|
||||||
|
using PipelineStateQ = typename cutlass::PipelineState<MainloopPipelineQ::Stages>;
|
||||||
|
|
||||||
|
static const int MinBlocksPerMultiprocessor = kBlocksPerSM;
|
||||||
|
static const int MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock;
|
||||||
|
using ArchTag = cutlass::arch::Sm90;
|
||||||
|
|
||||||
|
static size_t get_workspace_size(Arguments const& args) { return 0; }
|
||||||
|
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
|
||||||
|
return cutlass::Status::kSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool can_implement(Arguments const& args) {
|
||||||
|
return CollectiveMainloop::can_implement(args.problem_size, args.mainloop);
|
||||||
|
}
|
||||||
|
|
||||||
|
static dim3 get_grid_shape(Params const& params) {
|
||||||
|
return TileScheduler::get_grid_shape(params.tile_scheduler);
|
||||||
|
}
|
||||||
|
|
||||||
|
static dim3 get_block_shape() {
|
||||||
|
dim3 block(MaxThreadsPerBlock, 1, 1);
|
||||||
|
return block;
|
||||||
|
}
|
||||||
|
|
||||||
|
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||||
|
return Params{
|
||||||
|
args.problem_size,
|
||||||
|
CollectiveMainloop::to_underlying_arguments(args.problem_size, args.mainloop, workspace),
|
||||||
|
CollectiveEpilogue::to_underlying_arguments(args.problem_size, args.epilogue, workspace),
|
||||||
|
TileScheduler::to_underlying_arguments(args.problem_size, args.hw_info, ClusterShape{}, TileShape{})
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
||||||
|
TileScheduler tile_scheduler{params.tile_scheduler};
|
||||||
|
|
||||||
|
// Shared memory.
|
||||||
|
auto& storage = *reinterpret_cast<SharedStorage*>(smem);
|
||||||
|
|
||||||
|
int thread_idx = int(threadIdx.x);
|
||||||
|
|
||||||
|
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
|
||||||
|
|
||||||
|
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||||
|
int warp_group_thread_idx = thread_idx % cutlass::NumThreadsPerWarpGroup;
|
||||||
|
int lane_predicate = cute::elect_one_sync();
|
||||||
|
|
||||||
|
// Issue Tma Descriptor Prefetch from a single thread
|
||||||
|
if ((warp_idx == 0) && lane_predicate) {
|
||||||
|
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
PipelineParamsQ pipeline_params_q;
|
||||||
|
pipeline_params_q.transaction_bytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element); // Q
|
||||||
|
pipeline_params_q.role = MainloopPipelineQ::ThreadCategory::ProducerConsumer;
|
||||||
|
pipeline_params_q.is_leader = warp_group_thread_idx == 0;
|
||||||
|
pipeline_params_q.num_consumers = cutlass::NumThreadsPerWarpGroup;
|
||||||
|
|
||||||
|
PipelineParams pipeline_params;
|
||||||
|
pipeline_params.transaction_bytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element); // KV
|
||||||
|
pipeline_params.role = MainloopPipeline::ThreadCategory::ProducerConsumer;
|
||||||
|
pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||||
|
pipeline_params.num_consumers = cutlass::NumThreadsPerWarpGroup;
|
||||||
|
|
||||||
|
MainloopPipelineQ pipeline_q(storage.pipeline_storage_q, pipeline_params_q, Shape<_1, _1, _1>{});
|
||||||
|
MainloopPipeline pipeline(storage.pipeline_storage, pipeline_params, ClusterShape{});
|
||||||
|
|
||||||
|
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
|
||||||
|
typename EpiLoadPipeline::Params epi_load_pipeline_params;
|
||||||
|
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::ProducerConsumer;
|
||||||
|
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
|
||||||
|
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
|
||||||
|
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
|
||||||
|
epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes;
|
||||||
|
EpiLoadPipeline epi_load_pipeline(storage.epi_load, epi_load_pipeline_params);
|
||||||
|
|
||||||
|
// State variables used for iterating the circular buffer
|
||||||
|
// smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA
|
||||||
|
// smem_pipe_write is used by the producer of SMEM data - i.e TMA
|
||||||
|
PipelineState smem_pipe_read;
|
||||||
|
PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||||
|
|
||||||
|
PipelineStateQ smem_pipe_read_q;
|
||||||
|
PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
|
||||||
|
|
||||||
|
// We need this to guarantee that the Pipeline init is visible
|
||||||
|
// To all producers and consumer blocks in the Cluster
|
||||||
|
// and to finish smem init
|
||||||
|
if constexpr (size(ClusterShape{}) > 1) {
|
||||||
|
cute::cluster_arrive_relaxed();
|
||||||
|
cute::cluster_wait();
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto blk_coord = tile_scheduler.get_block_coord();
|
||||||
|
|
||||||
|
CollectiveMainloop collective_mainloop;
|
||||||
|
auto result = collective_mainloop.compute(
|
||||||
|
block_rank_in_cluster,
|
||||||
|
blk_coord, params.mainloop, params.problem_size,
|
||||||
|
pipeline, smem_pipe_read, smem_pipe_write,
|
||||||
|
pipeline_q, smem_pipe_read_q, smem_pipe_write_q,
|
||||||
|
storage.mainloop
|
||||||
|
);
|
||||||
|
|
||||||
|
CollectiveEpilogue epilogue;
|
||||||
|
epilogue(typename CollectiveMainloop::TileShapePV{}, blk_coord,
|
||||||
|
result, typename CollectiveMainloop::TiledMmaPV{},
|
||||||
|
params.problem_size, params.epilogue,
|
||||||
|
epi_load_pipeline, storage.epilogue);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace cutlass::fmha::kernel
|
||||||
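The `transaction_bytes` fields above must match exactly what one TMA load deposits into shared memory per pipeline stage. A worked instance with a hypothetical 128x64 half-precision Q tile per stage; the tile sizes are illustrative, not taken from this diff.

```cpp
#include <cstddef>
#include <cstdio>
#include "cutlass/half.h"

int main() {
  // Mirrors size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element) for one stage.
  constexpr int kBlockQO = 128, kBlockHead = 64;
  constexpr std::size_t bytes = std::size_t(kBlockQO) * kBlockHead * sizeof(cutlass::half_t);
  std::printf("TMA transaction bytes per Q stage: %zu\n", bytes);  // 16384
}
```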
@@ -0,0 +1,418 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/arch/reg_reconfig.h"
|
||||||
|
#include "cutlass/pipeline/pipeline.hpp"
|
||||||
|
#include "cutlass/arch/arch.h"
|
||||||
|
|
||||||
|
#include "../kernel/fmha_options.hpp"
|
||||||
|
|
||||||
|
namespace cutlass::fmha::kernel {
|
||||||
|
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
template<
|
||||||
|
class CollectiveMainloop,
|
||||||
|
class CollectiveEpilogue,
|
||||||
|
class TileScheduler,
|
||||||
|
class... Options
|
||||||
|
>
|
||||||
|
struct FmhaKernelTmaWarpSpecialized {
|
||||||
|
|
||||||
|
// Options
|
||||||
|
static constexpr bool kIsEpilogueLocked = find_option_t<Tag::kIsEpilogueLocked, false_type, Options...>::value;
|
||||||
|
static constexpr bool kLoadsQSeparately = find_option_t<Tag::kLoadsQSeparately, false_type, Options...>::value;
|
||||||
|
|
||||||
|
|
||||||
|
static const int NumLoadWarpGroups = 1;
|
||||||
|
static constexpr int NumMmaWarpGroups = CollectiveMainloop::NumMmaWarpGroups;
|
||||||
|
|
||||||
|
using TileShape = typename CollectiveMainloop::TileShape;
|
||||||
|
using ClusterShape = typename CollectiveMainloop::ClusterShape;
|
||||||
|
|
||||||
|
using MainloopPipelineOuter = typename CollectiveMainloop::MainloopPipelineQ;
|
||||||
|
using MainloopPipelineInner = typename CollectiveMainloop::MainloopPipeline;
|
||||||
|
using MainloopPipelineReducer = cutlass::PipelineAsync<2>;
|
||||||
|
|
||||||
|
static constexpr uint32_t StagesPerMathWarpGroup = 2;
|
||||||
|
using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier<
|
||||||
|
StagesPerMathWarpGroup, NumMmaWarpGroups>;
|
||||||
|
|
||||||
|
struct TensorStorageStruct {
|
||||||
|
typename CollectiveMainloop::SharedStorage mainloop;
|
||||||
|
typename CollectiveEpilogue::TensorStorage epilogue[NumMmaWarpGroups];
|
||||||
|
};
|
||||||
|
union TensorStorageUnion {
|
||||||
|
typename CollectiveMainloop::SharedStorage mainloop;
|
||||||
|
typename CollectiveEpilogue::TensorStorage epilogue[NumMmaWarpGroups];
|
||||||
|
};
|
||||||
|
using TensorStorage = std::conditional_t<CollectiveMainloop::kIsPersistent, TensorStorageStruct, TensorStorageUnion>;
|
||||||
|
|
||||||
|
struct SharedStorage {
|
||||||
|
TensorStorage tensors;
|
||||||
|
|
||||||
|
using PipelineStorageInner = typename MainloopPipelineInner::SharedStorage;
|
||||||
|
using PipelineStorageOuter = typename MainloopPipelineOuter::SharedStorage;
|
||||||
|
using PipelineStorageReducer = typename MainloopPipelineReducer::SharedStorage;
|
||||||
|
|
||||||
|
alignas(16) PipelineStorageInner pipeline_storage_inner;
|
||||||
|
alignas(16) PipelineStorageOuter pipeline_storage_outer;
|
||||||
|
alignas(16) PipelineStorageReducer pipeline_storage_reducer;
|
||||||
|
|
||||||
|
using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage;
|
||||||
|
alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order;
|
||||||
|
|
||||||
|
alignas(16) cutlass::arch::ClusterBarrier load_warp_barrier;
|
||||||
|
|
||||||
|
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
|
||||||
|
alignas(16) EpiLoadPipelineStorage epi_load;
|
||||||
|
};
|
||||||
|
|
||||||
|
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
||||||
|
|
||||||
|
using ProblemShape = cute::tuple<int, int, int, int, int>;
|
||||||
|
|
||||||
|
struct Arguments {
|
||||||
|
ProblemShape problem_size;
|
||||||
|
typename CollectiveMainloop::Arguments mainloop;
|
||||||
|
typename CollectiveEpilogue::Arguments epilogue;
|
||||||
|
KernelHardwareInfo hw_info;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Params {
|
||||||
|
ProblemShape problem_size;
|
||||||
|
typename CollectiveMainloop::Params mainloop;
|
||||||
|
typename CollectiveEpilogue::Params epilogue;
|
||||||
|
typename TileScheduler::Params tile_scheduler;
|
||||||
|
};
|
||||||
|
|
||||||
|
using PipelineParamsInner = typename MainloopPipelineInner::Params;
|
||||||
|
using PipelineStateInner = typename cutlass::PipelineState<MainloopPipelineInner::Stages>;
|
||||||
|
using PipelineParamsOuter = typename MainloopPipelineOuter::Params;
|
||||||
|
using PipelineStateOuter = typename cutlass::PipelineState<MainloopPipelineOuter::Stages>;
|
||||||
|
using PipelineParamsReducer = typename MainloopPipelineReducer::Params;
|
||||||
|
using PipelineStateReducer = typename cutlass::PipelineState<MainloopPipelineReducer::Stages>;
|
||||||
|
|
||||||
|
static const int MinBlocksPerMultiprocessor = 1;
|
||||||
|
static const int MaxThreadsPerBlock = (NumMmaWarpGroups + NumLoadWarpGroups) * cutlass::NumThreadsPerWarpGroup;
|
||||||
|
using ArchTag = cutlass::arch::Sm90;
|
||||||
|
|
||||||
|
static constexpr uint32_t LoadRegisterRequirement = 40 - 2 * 8;
|
||||||
|
static constexpr uint32_t TotalRegisterSupply = (64*1024 / MaxThreadsPerBlock / MinBlocksPerMultiprocessor / 8) * 8 * MaxThreadsPerBlock / cutlass::NumThreadsPerWarpGroup;
|
||||||
|
static constexpr uint32_t MmaRegisterRequirement = ((TotalRegisterSupply - LoadRegisterRequirement) / NumMmaWarpGroups / 8) * 8;
|
||||||
|
|
||||||
|
static size_t get_workspace_size(Arguments const& args) { return 0; }
|
||||||
|
static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) {
|
||||||
|
return cutlass::Status::kSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool can_implement(Arguments const& args) {
|
||||||
|
return CollectiveMainloop::can_implement(args.problem_size, args.mainloop);
|
||||||
|
}
|
||||||
|
|
||||||
|
static dim3 get_grid_shape(Params const& params) {
|
||||||
|
return TileScheduler::get_grid_shape(params.tile_scheduler);
|
||||||
|
}
|
||||||
|
|
||||||
|
static dim3 get_block_shape() {
|
||||||
|
dim3 block(MaxThreadsPerBlock, 1, 1);
|
||||||
|
return block;
|
||||||
|
}
|
||||||
|
|
||||||
|
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||||
|
return Params{
|
||||||
|
args.problem_size,
|
||||||
|
CollectiveMainloop::to_underlying_arguments(args.problem_size, args.mainloop, workspace),
|
||||||
|
CollectiveEpilogue::to_underlying_arguments(args.problem_size, args.epilogue, workspace),
|
||||||
|
TileScheduler::to_underlying_arguments(args.problem_size, args.hw_info, ClusterShape{}, TileShape{})
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) {
|
||||||
|
|
||||||
|
enum class WarpGroupRole {
|
||||||
|
Producer = 0,
|
||||||
|
Consumer0 = 1,
|
||||||
|
Consumer1 = 2,
|
||||||
|
Consumer2 = 3,
|
||||||
|
Consumer3 = 4,
|
||||||
|
};
|
||||||
|
enum class ProducerWarpRole {
|
||||||
|
LoadKV = 1,
|
||||||
|
Reducer = 0,
|
||||||
|
MaybeLoadQ = 2, // if kLoadsQSeparately is true, this warp loads Q (otherwise warp 0 does it)
|
||||||
|
MainloopEpilogue = 3,
|
||||||
|
};
|
||||||
|
|
||||||
|
static constexpr ProducerWarpRole WarpRoleLoadQ = kLoadsQSeparately ? ProducerWarpRole::MaybeLoadQ : ProducerWarpRole::LoadKV;
|
||||||
|
|
||||||
|
TileScheduler tile_scheduler{params.tile_scheduler};
|
||||||
|
|
||||||
|
// Shared memory.
|
||||||
|
auto& storage = *reinterpret_cast<SharedStorage*>(smem);
|
||||||
|
|
||||||
|
int lane_idx = cutlass::canonical_lane_idx();
|
||||||
|
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||||
|
int warp_idx_in_warp_group = warp_idx % cutlass::NumWarpsPerWarpGroup;
|
||||||
|
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
||||||
|
auto warp_group_role = WarpGroupRole(warp_group_idx);
|
||||||
|
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
|
||||||
|
int consumer_warp_group_idx = warp_group_idx - (int) WarpGroupRole::Consumer0;
|
||||||
|
int lane_predicate = cute::elect_one_sync();
|
||||||
|
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
|
||||||
|
|
||||||
|
// Issue Tma Descriptor Prefetch from a single thread
|
||||||
|
if ((warp_idx == 0) && lane_predicate) {
|
||||||
|
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
||||||
|
}
|
||||||
|
|
||||||
|
PipelineParamsOuter pipeline_params_outer;
|
||||||
|
pipeline_params_outer.transaction_bytes = CollectiveMainloop::kOuterLoadBytes;
|
||||||
|
pipeline_params_outer.is_leader = lane_predicate && (producer_warp_role == WarpRoleLoadQ);
|
||||||
|
pipeline_params_outer.num_consumers = cutlass::NumThreadsPerWarpGroup;
|
||||||
|
|
||||||
|
PipelineParamsInner pipeline_params_inner;
|
||||||
|
pipeline_params_inner.transaction_bytes = CollectiveMainloop::kInnerLoadBytes;
|
||||||
|
pipeline_params_inner.is_leader = lane_predicate && (producer_warp_role == ProducerWarpRole::LoadKV);
|
||||||
|
pipeline_params_inner.num_consumers = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
|
||||||
|
|
||||||
|
PipelineParamsReducer pipeline_params_reducer;
|
||||||
|
pipeline_params_reducer.producer_arv_count = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
|
||||||
|
pipeline_params_reducer.consumer_arv_count = cutlass::NumThreadsPerWarp;
|
||||||
|
|
||||||
|
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
|
||||||
|
typename EpiLoadPipeline::Params epi_load_pipeline_params;
|
||||||
|
|
||||||
|
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::MainloopEpilogue) {
|
||||||
|
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
|
||||||
|
}
|
||||||
|
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::LoadKV) {
|
||||||
|
pipeline_params_inner.role = MainloopPipelineInner::ThreadCategory::Producer;
|
||||||
|
}
|
||||||
|
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == WarpRoleLoadQ) {
|
||||||
|
pipeline_params_outer.role = MainloopPipelineOuter::ThreadCategory::Producer;
|
||||||
|
}
|
||||||
|
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Reducer) {
|
||||||
|
pipeline_params_reducer.role = MainloopPipelineReducer::ThreadCategory::Consumer;
|
||||||
|
}
|
||||||
|
if (warp_group_role == WarpGroupRole::Consumer0 ||
|
||||||
|
warp_group_role == WarpGroupRole::Consumer1 ||
|
||||||
|
warp_group_role == WarpGroupRole::Consumer2 ||
|
||||||
|
warp_group_role == WarpGroupRole::Consumer3
|
||||||
|
) {
|
||||||
|
pipeline_params_inner.role = MainloopPipelineInner::ThreadCategory::Consumer;
|
||||||
|
pipeline_params_outer.role = MainloopPipelineOuter::ThreadCategory::Consumer;
|
||||||
|
pipeline_params_reducer.role = MainloopPipelineReducer::ThreadCategory::Producer;
|
||||||
|
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
|
||||||
|
}
|
||||||
|
|
||||||
|
MainloopPipelineOuter pipeline_outer(storage.pipeline_storage_outer, pipeline_params_outer, Shape<_1, _1, _1>{});
|
||||||
|
MainloopPipelineInner pipeline_inner(storage.pipeline_storage_inner, pipeline_params_inner, ClusterShape{});
|
||||||
|
MainloopPipelineReducer pipeline_reducer(storage.pipeline_storage_reducer, pipeline_params_reducer);
|
||||||
|
|
||||||
|
// State variables used for iterating the circular buffer
|
||||||
|
// smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA
|
||||||
|
// smem_pipe_write is used by the producer of SMEM data - i.e TMA
|
||||||
|
PipelineStateInner smem_pipe_read_inner;
|
||||||
|
PipelineStateInner smem_pipe_write_inner = cutlass::make_producer_start_state<MainloopPipelineInner>();
|
||||||
|
|
||||||
|
PipelineStateOuter smem_pipe_read_outer;
|
||||||
|
PipelineStateOuter smem_pipe_write_outer = cutlass::make_producer_start_state<MainloopPipelineOuter>();
|
||||||
|
|
||||||
|
PipelineStateReducer smem_pipe_read_reducer;
|
||||||
|
PipelineStateReducer smem_pipe_write_reducer = cutlass::make_producer_start_state<MainloopPipelineReducer>();
|
||||||
|
|
||||||
|
typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier;
|
||||||
|
// DMA Load WG will not participate in these Ordered Barrier syncs
|
||||||
|
params_math_wg_order_barrier.group_id = consumer_warp_group_idx;
|
||||||
|
params_math_wg_order_barrier.group_size = cutlass::NumThreadsPerWarpGroup; // Number of threads / participants in a group
|
||||||
|
MathWarpGroupOrderBarrier math_wg_order_barrier(storage.math_wg_order, params_math_wg_order_barrier);
|
||||||
|
|
||||||
|
// Epilogue Load pipeline
|
||||||
|
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
|
||||||
|
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
|
||||||
|
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
|
||||||
|
epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes;
|
||||||
|
EpiLoadPipeline epi_load_pipeline(storage.epi_load, epi_load_pipeline_params);
|
||||||
|
|
||||||
|
if constexpr (kLoadsQSeparately) {
|
||||||
|
if ((warp_idx == 0) && lane_predicate) {
|
||||||
|
storage.load_warp_barrier.init(2 * cutlass::NumThreadsPerWarp);
|
||||||
|
}
|
||||||
|
cutlass::arch::fence_barrier_init();
|
||||||
|
}
|
||||||
|
|
||||||
|
// We need this to guarantee that the Pipeline init is visible
|
||||||
|
// To all producers and consumer blocks in the Cluster
|
||||||
|
// and to finish smem init
|
||||||
|
if constexpr (size(ClusterShape{}) > 1) {
|
||||||
|
cute::cluster_arrive_relaxed();
|
||||||
|
cute::cluster_wait();
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
CollectiveMainloop collective_mainloop;
|
||||||
|
|
||||||
|
if (warp_group_role == WarpGroupRole::Producer) {
|
||||||
|
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
|
||||||
|
if (producer_warp_role == ProducerWarpRole::LoadKV) {
|
||||||
|
bool do_barrier = kLoadsQSeparately;
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_NO_UNROLL
|
||||||
|
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||||
|
auto blk_coord = tile_scheduler.get_block_coord();
|
||||||
|
collective_mainloop.template load_kv_maybe_q<!kLoadsQSeparately>(
|
||||||
|
block_rank_in_cluster,
|
||||||
|
blk_coord, params.mainloop, params.problem_size,
|
||||||
|
pipeline_inner, smem_pipe_write_inner,
|
||||||
|
pipeline_outer, smem_pipe_write_outer,
|
||||||
|
storage.tensors.mainloop,
|
||||||
|
storage.load_warp_barrier, do_barrier
|
||||||
|
);
|
||||||
|
do_barrier = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (kLoadsQSeparately && (producer_warp_role == ProducerWarpRole::MaybeLoadQ)) {
|
||||||
|
bool do_barrier = true;
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_NO_UNROLL
|
||||||
|
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||||
|
auto blk_coord = tile_scheduler.get_block_coord();
|
||||||
|
collective_mainloop.load_maybe_q(
|
||||||
|
blk_coord, params.mainloop, params.problem_size,
|
||||||
|
pipeline_outer, smem_pipe_write_outer,
|
||||||
|
storage.tensors.mainloop,
|
||||||
|
storage.load_warp_barrier, do_barrier
|
||||||
|
);
|
||||||
|
do_barrier = false;
|
||||||
|
}
|
||||||
|
} else if (producer_warp_role == ProducerWarpRole::Reducer) {
|
||||||
|
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||||
|
auto blk_coord = tile_scheduler.get_block_coord();
|
||||||
|
collective_mainloop.reduce(
|
||||||
|
blk_coord, params.mainloop, params.problem_size,
|
||||||
|
pipeline_reducer, smem_pipe_read_reducer,
|
||||||
|
storage.tensors.mainloop
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (
|
||||||
|
warp_group_role == WarpGroupRole::Consumer0 ||
|
||||||
|
warp_group_role == WarpGroupRole::Consumer1 ||
|
||||||
|
warp_group_role == WarpGroupRole::Consumer2 ||
|
||||||
|
warp_group_role == WarpGroupRole::Consumer3
|
||||||
|
) {
|
||||||
|
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
|
||||||
|
CUTLASS_PRAGMA_NO_UNROLL
|
||||||
|
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||||
|
auto blk_coord = tile_scheduler.get_block_coord();
|
||||||
|
auto wg_coord = blk_coord;
|
||||||
|
|
||||||
|
constexpr int kOuterLoads = CollectiveMainloop::kOuterLoads;
|
||||||
|
|
||||||
|
if (warp_group_role == WarpGroupRole::Consumer0) {
|
||||||
|
smem_pipe_read_outer.advance(0 * kOuterLoads);
|
||||||
|
}
|
||||||
|
else if (warp_group_role == WarpGroupRole::Consumer1) {
|
||||||
|
smem_pipe_read_outer.advance(1 * kOuterLoads);
|
||||||
|
}
|
||||||
|
else if (warp_group_role == WarpGroupRole::Consumer2) {
|
||||||
|
smem_pipe_read_outer.advance(2 * kOuterLoads);
|
||||||
|
}
|
||||||
|
else if (warp_group_role == WarpGroupRole::Consumer3) {
|
||||||
|
smem_pipe_read_outer.advance(3 * kOuterLoads);
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr int wg_dim = is_constant<0, decltype(get<1>(wg_coord))>::value ? 0 : 1;
|
||||||
|
auto& wg_block = get<wg_dim>(wg_coord);
|
||||||
|
if (warp_group_role == WarpGroupRole::Consumer0) {
|
||||||
|
wg_block = NumMmaWarpGroups * wg_block + 0;
|
||||||
|
}
|
||||||
|
else if (warp_group_role == WarpGroupRole::Consumer1) {
|
||||||
|
wg_block = NumMmaWarpGroups * wg_block + 1;
|
||||||
|
}
|
||||||
|
else if (warp_group_role == WarpGroupRole::Consumer2) {
|
||||||
|
wg_block = NumMmaWarpGroups * wg_block + 2;
|
||||||
|
}
|
||||||
|
else if (warp_group_role == WarpGroupRole::Consumer3) {
|
||||||
|
wg_block = NumMmaWarpGroups * wg_block + 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto result = collective_mainloop.compute(
|
||||||
|
blk_coord, wg_coord,
|
||||||
|
params.mainloop, params.problem_size,
|
||||||
|
pipeline_inner, smem_pipe_read_inner,
|
||||||
|
pipeline_outer, smem_pipe_read_outer,
|
||||||
|
pipeline_reducer, smem_pipe_write_reducer,
|
||||||
|
storage.tensors.mainloop,
|
||||||
|
math_wg_order_barrier
|
||||||
|
);
|
||||||
|
|
||||||
|
if (warp_group_role == WarpGroupRole::Consumer0) {
|
||||||
|
smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 0));
|
||||||
|
}
|
||||||
|
if constexpr (NumMmaWarpGroups >= 2) {
|
||||||
|
if (warp_group_role == WarpGroupRole::Consumer1) {
|
||||||
|
smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if constexpr (NumMmaWarpGroups >= 3) {
|
||||||
|
if (warp_group_role == WarpGroupRole::Consumer2) {
|
||||||
|
smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 2));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if constexpr (NumMmaWarpGroups >= 4) {
|
||||||
|
if (warp_group_role == WarpGroupRole::Consumer3) {
|
||||||
|
smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 3));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if constexpr (kIsEpilogueLocked) { math_wg_order_barrier.wait(); }
|
||||||
|
|
||||||
|
CollectiveEpilogue epilogue;
|
||||||
|
epilogue(typename CollectiveMainloop::TileShapePV{}, wg_coord,
|
||||||
|
result, typename CollectiveMainloop::TiledMmaPV{},
|
||||||
|
params.problem_size, params.epilogue,
|
||||||
|
epi_load_pipeline, storage.tensors.epilogue[consumer_warp_group_idx]);
|
||||||
|
|
||||||
|
if constexpr (kIsEpilogueLocked) { math_wg_order_barrier.arrive(); }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace cutlass::fmha::kernel
|
||||||
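The register re-partitioning constants above can be made concrete. Assuming the two-MMA-warp-group cooperative configuration (NumMmaWarpGroups = 2 is an assumption here, not a value fixed by this file), the formulas evaluate as follows.

```cpp
// Integer arithmetic copied from the kernel's definitions; only NumMmaWarpGroups is assumed.
constexpr int NumMmaWarpGroups = 2;
constexpr int NumThreadsPerWarpGroup = 128;
constexpr int MaxThreadsPerBlock = (NumMmaWarpGroups + 1) * NumThreadsPerWarpGroup;  // 384
constexpr int MinBlocksPerMultiprocessor = 1;

constexpr unsigned LoadRegisterRequirement = 40 - 2 * 8;  // 24 regs/thread for the producer warp group
constexpr unsigned TotalRegisterSupply =
    (64 * 1024 / MaxThreadsPerBlock / MinBlocksPerMultiprocessor / 8) * 8
    * MaxThreadsPerBlock / NumThreadsPerWarpGroup;                                   // 504
constexpr unsigned MmaRegisterRequirement =
    ((TotalRegisterSupply - LoadRegisterRequirement) / NumMmaWarpGroups / 8) * 8;    // 240

static_assert(TotalRegisterSupply == 504 && MmaRegisterRequirement == 240);
// Budget check: 24 + 2 * 240 == 504 per-thread registers summed over the three warp groups.
static_assert(LoadRegisterRequirement + NumMmaWarpGroups * MmaRegisterRequirement <= TotalRegisterSupply);
```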
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  * SPDX-License-Identifier: BSD-3-Clause
  *
  * Redistribution and use in source and binary forms, with or without
@@ -28,51 +28,56 @@
  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  *
  **************************************************************************************************/

 #pragma once

-#include <cute/config.hpp>                     // CUTE_HOST_DEVICE
-#include <cute/numeric/integral_constant.hpp>  // cute::true_type
+#include "cutlass/cutlass.h"

-namespace cute
-{
+namespace cutlass::fmha::kernel {

-template <class T>
-struct ConstantTensor
-{
-  template <class... Coords>
-  CUTE_HOST_DEVICE constexpr
-  T const&
-  operator()(Coords const&...) const {
-    return val_;
-  }
-
-  T val_;
-};
+template<auto kTag, typename Default, typename... Options>
+struct find_option;
+
+template<auto kTag, typename Default>
+struct find_option<kTag, Default> {
+  using option_value = Default;
+};

-struct TrivialPredTensor
-{
-  template <class... Coords>
-  CUTE_HOST_DEVICE constexpr
-  true_type
-  operator()(Coords const&...) const {
-    return {};
-  }
-};
+template<auto kTag, typename Default, typename Option, typename... Options>
+struct find_option<kTag, Default, Option, Options...> :
+  std::conditional_t<
+    Option::tag == kTag,
+    Option,
+    find_option<kTag, Default, Options...>
+  >
+{};
+
+template<auto kTag, typename Default, typename... Options>
+using find_option_t = typename find_option<kTag, Default, Options...>::option_value;
+
+enum class Tag {
+  kIsPersistent,
+  kNumMmaWarpGroups,
+  kLoadsQSeparately,
+
+  kIsMainloopLocked,
+  kIsEpilogueLocked,
+
+  kStagesQ,
+  kStagesKV,
+
+  kEpilogueKind,
+
+  kBlocksPerSM,
+  kClusterM,
+
+  kAccQK
+};

-template <class Fn>
-struct FunctionPredTensor
-{
-  CUTE_HOST_DEVICE constexpr
-  FunctionPredTensor(Fn const& fn) : fn_(fn) {}
-
-  template <class... Coords>
-  CUTE_HOST_DEVICE constexpr
-  auto
-  operator()(Coords const&... coords) const {
-    return fn_(coords...);
-  }
-
-  Fn const& fn_;
-};
+template<auto kTag, class Value>
+struct Option {
+  static constexpr auto tag = kTag;
+  using option_value = Value;
+};

-} // end namespace cute
+} // namespace cutlass::fmha::kernel
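A hedged usage sketch of the option machinery in that file follows. The `cute::true_type`/`cute::false_type` spellings are an assumption; the builder and kernel headers in this change use the unqualified names.

```cpp
#include <cute/numeric/integral_constant.hpp>  // cute::true_type / cute::false_type
// plus the options header shown above, e.g. "../kernel/fmha_options.hpp"

using namespace cutlass::fmha::kernel;

// An option list as the builder would pass it via `Options...`.
using Persistent = Option<Tag::kIsPersistent, cute::true_type>;

// find_option_t picks the value attached to kIsPersistent if such an Option is
// present, otherwise it falls back to the supplied default.
static_assert( find_option_t<Tag::kIsPersistent, cute::false_type, Persistent>::value);
static_assert(!find_option_t<Tag::kIsPersistent, cute::false_type>::value);
```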
204  examples/88_hopper_fmha/kernel/fmha_tile_scheduler.hpp (new file)
@@ -0,0 +1,204 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/fast_math.h"
|
||||||
|
#include "cutlass/kernel_hardware_info.h"
|
||||||
|
|
||||||
|
namespace cutlass::fmha::kernel {
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
struct IndividualTileScheduler {
|
||||||
|
|
||||||
|
struct Params {
|
||||||
|
dim3 grid;
|
||||||
|
};
|
||||||
|
|
||||||
|
bool valid_ = true;
|
||||||
|
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
IndividualTileScheduler(Params const&) {}
|
||||||
|
|
||||||
|
template<class ProblemSize, class ClusterShape, class TileShape>
|
||||||
|
static Params to_underlying_arguments(
|
||||||
|
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
|
||||||
|
ClusterShape const& cluster_shape, TileShape const& tile_shape)
|
||||||
|
{
|
||||||
|
using namespace cute;
|
||||||
|
dim3 grid(round_up(ceil_div(size<2>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<0>(problem_size), size<1>(problem_size));
|
||||||
|
return Params{ grid };
|
||||||
|
}
|
||||||
|
|
||||||
|
static dim3 get_grid_shape(Params const& params) {
|
||||||
|
return params.grid;
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
bool is_valid() {
|
||||||
|
return valid_;
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
auto get_block_coord() {
|
||||||
|
using namespace cute;
|
||||||
|
return make_coord(blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z));
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
IndividualTileScheduler& operator++() {
|
||||||
|
valid_ = false;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
struct PersistentTileScheduler {
|
||||||
|
|
||||||
|
struct Params {
|
||||||
|
int num_blocks;
|
||||||
|
FastDivmod divmod_m_block;
|
||||||
|
FastDivmod divmod_b;
|
||||||
|
FastDivmod divmod_h;
|
||||||
|
|
||||||
|
KernelHardwareInfo hw_info;
|
||||||
|
};
|
||||||
|
|
||||||
|
int block_idx = 0;
|
||||||
|
Params params;
|
||||||
|
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
PersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
|
||||||
|
|
||||||
|
template<class ProblemSize, class ClusterShape, class TileShape>
|
||||||
|
static Params to_underlying_arguments(
|
||||||
|
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
|
||||||
|
ClusterShape const& cluster_shape, TileShape const& tile_shape)
|
||||||
|
{
|
||||||
|
using namespace cute;
|
||||||
|
// Get SM count if needed, otherwise use user supplied SM count
|
||||||
|
int sm_count = hw_info.sm_count;
|
||||||
|
if (sm_count <= 0) {
|
||||||
|
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
||||||
|
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
||||||
|
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
||||||
|
hw_info.sm_count = sm_count;
|
||||||
|
|
||||||
|
int num_m_blocks = cutlass::round_up(ceil_div(size<2>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape));
|
||||||
|
int num_blocks = num_m_blocks * size<0>(problem_size) * size<1>(problem_size);
|
||||||
|
|
||||||
|
return Params {
|
||||||
|
num_blocks,
|
||||||
|
{ num_m_blocks}, { size<0>(problem_size) }, { size<1>(problem_size) },
|
||||||
|
hw_info
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
static dim3 get_grid_shape(Params const& params) {
|
||||||
|
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
|
||||||
|
return grid;
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
bool is_valid() {
|
||||||
|
return block_idx < params.num_blocks;
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
auto get_block_coord() {
|
||||||
|
using namespace cute;
|
||||||
|
int block_decode = block_idx;
|
||||||
|
int m_block, bidb, bidh;
|
||||||
|
params.divmod_m_block(block_decode, m_block, block_decode);
|
||||||
|
params.divmod_b(block_decode, bidb, block_decode);
|
||||||
|
params.divmod_h(block_decode, bidh, block_decode);
|
||||||
|
return make_coord(m_block, _0{}, make_coord(bidb, bidh));
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
PersistentTileScheduler& operator++() {
|
||||||
|
block_idx += gridDim.x;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
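// Added usage sketch, not part of the original header: how a device kernel would drive
// either scheduler defined above. `scheduler_loop` and the per-tile work are illustrative
// assumptions; only the Params/is_valid()/get_block_coord()/operator++ interface comes from
// this file. IndividualTileScheduler visits exactly one tile per CTA, while
// PersistentTileScheduler strides through all tiles by gridDim.x until is_valid() fails.
template <class TileScheduler>
CUTLASS_DEVICE void scheduler_loop(typename TileScheduler::Params const& params) {
  for (TileScheduler scheduler(params); scheduler.is_valid(); ++scheduler) {
    auto blk_coord = scheduler.get_block_coord();  // (m_block, _0{}, (bidb, bidh))
    // ... load Q/K/V tiles for blk_coord, run the mainloop and epilogue ...
    (void)blk_coord;                               // placeholder for the real tile work
  }
}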
|
||||||
|
|
||||||
|
template<typename Base>
|
||||||
|
struct TileSchedulerBwdAdapter {
|
||||||
|
|
||||||
|
using Params = typename Base::Params;
|
||||||
|
|
||||||
|
Base base_;
|
||||||
|
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
TileSchedulerBwdAdapter(Params const& params) : base_(params) {}
|
||||||
|
|
||||||
|
template<class ProblemSize, class ClusterShape, class TileShape>
|
||||||
|
static Params to_underlying_arguments(
|
||||||
|
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
|
||||||
|
ClusterShape const& cluster_shape, TileShape const& tile_shape)
|
||||||
|
{
|
||||||
|
using namespace cute;
|
||||||
|
return Base::to_underlying_arguments(select<0,1,3,2,4>(problem_size), hw_info, select<1,0,2>(cluster_shape), select<1,0,2>(tile_shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
static dim3 get_grid_shape(Params const& params) {
|
||||||
|
return Base::get_grid_shape(params);
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
bool is_valid() {
|
||||||
|
return base_.is_valid();
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
auto get_block_coord() {
|
||||||
|
using namespace cute;
|
||||||
|
return select<1,0,2>(base_.get_block_coord());
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
TileSchedulerBwdAdapter& operator++() {
|
||||||
|
++base_;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace cutlass::fmha::kernel
|
||||||
357 examples/88_hopper_fmha/reference/fmha_bwd_reference.hpp Normal file
@@ -0,0 +1,357 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "cute/tensor.hpp"
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template<
|
||||||
|
class ProblemShape,
|
||||||
|
class TensorQ, class TensorK, class TensorV,
|
||||||
|
class TensorO, class TensorLSE, class TensorDO,
|
||||||
|
class TensorDQ, /* class TensorDK, class TensorDV, */
|
||||||
|
class Fusion
|
||||||
|
>
|
||||||
|
void __global__ fmha_bwd_reference_dQ_kernel(
|
||||||
|
ProblemShape problem_shape,
|
||||||
|
TensorQ mQ, TensorK mK, TensorV mV,
|
||||||
|
TensorO mO, TensorLSE mLSE, TensorDO mDO,
|
||||||
|
TensorDQ mDQ, /* TensorDK mDK, TensorDV mDV, */
|
||||||
|
Fusion fusion
|
||||||
|
) {
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
using Element = typename TensorO::value_type;
|
||||||
|
using ElementAccumulator = typename TensorLSE::value_type;
|
||||||
|
|
||||||
|
extern __shared__ char mS_mem[];
|
||||||
|
Element* mS = reinterpret_cast<Element*>(mS_mem);
|
||||||
|
|
||||||
|
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));
|
||||||
|
|
||||||
|
for (int idx_L = blockIdx.y; idx_L < size<2>(mDQ); idx_L += gridDim.y) {
|
||||||
|
for (int idx_Q = blockIdx.x; idx_Q < size<0>(mDQ); idx_Q += gridDim.x) {
|
||||||
|
|
||||||
|
for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) {
|
||||||
|
ElementAccumulator acc_qk = 0;
|
||||||
|
ElementAccumulator acc_dov = 0;
|
||||||
|
ElementAccumulator acc_doo = 0;
|
||||||
|
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
|
||||||
|
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
|
||||||
|
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
|
||||||
|
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto id = make_identity_tensor(make_shape(1, 1));
|
||||||
|
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
|
||||||
|
frag(0) = acc_qk;
|
||||||
|
fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
|
||||||
|
acc_qk = frag(0);
|
||||||
|
|
||||||
|
mS[idx_K] = static_cast<Element>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int idx_D = threadIdx.x; idx_D < size<1>(mDQ); idx_D += blockDim.x) {
|
||||||
|
ElementAccumulator acc = 0;
|
||||||
|
for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
|
||||||
|
acc += mS[idx_K] * mK(idx_K, idx_D, idx_L);
|
||||||
|
}
|
||||||
|
mDQ(idx_Q, idx_D, idx_L) = acc;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
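// Added reading aid, not part of the original source. With $s = 1/\sqrt{d}$, $S = Q K^\top$,
// $dP = dO\,V^\top$ and $\Delta_i = \sum_d dO_{id}\,O_{id}$, the dQ kernel above evaluates
//   $P_{ij} = \exp(s\,S_{ij} - \mathrm{LSE}_i)$  and  $dQ = \big(s\,P \circ (dP - \Delta)\big)\,K$,
// i.e. mS[idx_K] holds $s\,P_{qk}\,(dP_{qk} - \Delta_q)$ for the current row $q$.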
|
||||||
|
|
||||||
|
template<
|
||||||
|
class ProblemShape,
|
||||||
|
class TensorQ, class TensorK, class TensorV,
|
||||||
|
class TensorO, class TensorLSE, class TensorDO,
|
||||||
|
/* class TensorDQ, */ class TensorDK, /* class TensorDV, */
|
||||||
|
class Fusion
|
||||||
|
>
|
||||||
|
void __global__ fmha_bwd_reference_dK_kernel(
|
||||||
|
ProblemShape problem_shape,
|
||||||
|
TensorQ mQ, TensorK mK, TensorV mV,
|
||||||
|
TensorO mO, TensorLSE mLSE, TensorDO mDO,
|
||||||
|
/* TensorDQ mDQ, */ TensorDK mDK, /* TensorDV mDV, */
|
||||||
|
Fusion fusion
|
||||||
|
) {
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
using Element = typename TensorO::value_type;
|
||||||
|
using ElementAccumulator = typename TensorLSE::value_type;
|
||||||
|
|
||||||
|
extern __shared__ char mS_mem[];
|
||||||
|
Element* mS = reinterpret_cast<Element*>(mS_mem);
|
||||||
|
|
||||||
|
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));
|
||||||
|
|
||||||
|
for (int idx_L = blockIdx.y; idx_L < size<2>(mDK); idx_L += gridDim.y) {
|
||||||
|
for (int idx_K = blockIdx.x; idx_K < size<0>(mDK); idx_K += gridDim.x) {
|
||||||
|
|
||||||
|
for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) {
|
||||||
|
ElementAccumulator acc_qk = 0;
|
||||||
|
ElementAccumulator acc_dov = 0;
|
||||||
|
ElementAccumulator acc_doo = 0;
|
||||||
|
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
|
||||||
|
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
|
||||||
|
acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L);
|
||||||
|
acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto id = make_identity_tensor(make_shape(1, 1));
|
||||||
|
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
|
||||||
|
frag(0) = acc_qk;
|
||||||
|
fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
|
||||||
|
acc_qk = frag(0);
|
||||||
|
|
||||||
|
mS[idx_Q] = static_cast<Element>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo));
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int idx_D = threadIdx.x; idx_D < size<1>(mDK); idx_D += blockDim.x) {
|
||||||
|
ElementAccumulator acc = 0;
|
||||||
|
for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) {
|
||||||
|
acc += mS[idx_Q] * mQ(idx_Q, idx_D, idx_L);
|
||||||
|
}
|
||||||
|
mDK(idx_K, idx_D, idx_L) = acc;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template<
|
||||||
|
class ProblemShape,
|
||||||
|
class TensorQ, class TensorK, class TensorV,
|
||||||
|
class TensorO, class TensorLSE, class TensorDO,
|
||||||
|
/* class TensorDQ, class TensorDK, */ class TensorDV,
|
||||||
|
class Fusion
|
||||||
|
>
|
||||||
|
void __global__ fmha_bwd_reference_dV_kernel(
|
||||||
|
ProblemShape problem_shape,
|
||||||
|
TensorQ mQ, TensorK mK, TensorV mV,
|
||||||
|
TensorO mO, TensorLSE mLSE, TensorDO mDO,
|
||||||
|
/* TensorDQ mDQ, TensorDK mDK, */ TensorDV mDV,
|
||||||
|
Fusion fusion
|
||||||
|
) {
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
using Element = typename TensorO::value_type;
|
||||||
|
using ElementAccumulator = typename TensorLSE::value_type;
|
||||||
|
|
||||||
|
extern __shared__ char mS_mem[];
|
||||||
|
Element* mS = reinterpret_cast<Element*>(mS_mem);
|
||||||
|
|
||||||
|
Element softmax_scale = static_cast<Element>(1.0 / sqrt(1.0 * size<1>(mO)));
|
||||||
|
|
||||||
|
for (int idx_L = blockIdx.y; idx_L < size<2>(mDV); idx_L += gridDim.y) {
|
||||||
|
for (int idx_K = blockIdx.x; idx_K < size<0>(mDV); idx_K += gridDim.x) {
|
||||||
|
|
||||||
|
for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) {
|
||||||
|
ElementAccumulator acc_qk = 0;
|
||||||
|
for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) {
|
||||||
|
acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto id = make_identity_tensor(make_shape(1, 1));
|
||||||
|
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
|
||||||
|
frag(0) = acc_qk;
|
||||||
|
fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
|
||||||
|
acc_qk = frag(0);
|
||||||
|
|
||||||
|
mS[idx_Q] = static_cast<Element>(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)));
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int idx_D = threadIdx.x; idx_D < size<1>(mDV); idx_D += blockDim.x) {
|
||||||
|
ElementAccumulator acc = 0;
|
||||||
|
for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) {
|
||||||
|
acc += mS[idx_Q] * mDO(idx_Q, idx_D, idx_L);
|
||||||
|
}
|
||||||
|
mDV(idx_K, idx_D, idx_L) = acc;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
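// Added reading aid, not part of the original source. The dK and dV kernels mirror the dQ
// kernel with the roles of Q and K exchanged: using the same $P$, $dP$ and $\Delta$ as above,
//   $dK = \big(s\,P \circ (dP - \Delta)\big)^\top Q$  and  $dV = P^\top dO$.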
|
||||||
|
|
||||||
|
template<
|
||||||
|
class ProblemShape,
|
||||||
|
class TensorQ, class TensorK, class TensorV,
|
||||||
|
class TensorO, class TensorLSE, class TensorDO,
|
||||||
|
/**/ class TensorDQ, /** / class TensorDK, / ** / class TensorDV, / **/
|
||||||
|
class Fusion
|
||||||
|
>
|
||||||
|
void fmha_bwd_reference_dQ(
|
||||||
|
ProblemShape problem_shape,
|
||||||
|
TensorQ mQ, TensorK mK, TensorV mV,
|
||||||
|
TensorO mO, TensorLSE mLSE, TensorDO mDO,
|
||||||
|
/**/ TensorDQ mDQ, /** / TensorDK mDK, / ** / TensorDV mDV, / **/
|
||||||
|
Fusion fusion
|
||||||
|
) {
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
dim3 grid(size<0>(mDQ), size<2>(mDQ), 1);
|
||||||
|
dim3 block(256);
|
||||||
|
int shared_mem = size<0>(mK) * sizeof(typename TensorO::value_type);
|
||||||
|
|
||||||
|
if (shared_mem >= (48 << 10)) {
|
||||||
|
CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem);
|
||||||
|
auto result = cudaFuncSetAttribute(
|
||||||
|
fmha_bwd_reference_dQ_kernel<ProblemShape, TensorQ, TensorK, TensorV, TensorO, TensorLSE, TensorDO, TensorDQ, Fusion>,
|
||||||
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||||
|
shared_mem);
|
||||||
|
if (cudaSuccess != result) {
|
||||||
|
result = cudaGetLastError(); // to clear the error bit
|
||||||
|
CUTLASS_TRACE_HOST(
|
||||||
|
" cudaFuncSetAttribute() returned error: "
|
||||||
|
<< cudaGetErrorString(result));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmha_bwd_reference_dQ_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion);
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template<
|
||||||
|
class ProblemShape,
|
||||||
|
class TensorQ, class TensorK, class TensorV,
|
||||||
|
class TensorO, class TensorLSE, class TensorDO,
|
||||||
|
/** / class TensorDQ, / **/ class TensorDK, /** / class TensorDV, / **/
|
||||||
|
class Fusion
|
||||||
|
>
|
||||||
|
void fmha_bwd_reference_dK(
|
||||||
|
ProblemShape problem_shape,
|
||||||
|
TensorQ mQ, TensorK mK, TensorV mV,
|
||||||
|
TensorO mO, TensorLSE mLSE, TensorDO mDO,
|
||||||
|
/** / TensorDQ mDQ, / **/ TensorDK mDK, /** / TensorDV mDV, / **/
|
||||||
|
Fusion fusion
|
||||||
|
) {
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
dim3 grid(size<0>(mDK), size<2>(mDK), 1);
|
||||||
|
dim3 block(256);
|
||||||
|
int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type);
|
||||||
|
|
||||||
|
if (shared_mem >= (48 << 10)) {
|
||||||
|
CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem);
|
||||||
|
auto result = cudaFuncSetAttribute(
|
||||||
|
fmha_bwd_reference_dK_kernel<ProblemShape, TensorQ, TensorK, TensorV, TensorO, TensorLSE, TensorDO, TensorDK, Fusion>,
|
||||||
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||||
|
shared_mem);
|
||||||
|
if (cudaSuccess != result) {
|
||||||
|
result = cudaGetLastError(); // to clear the error bit
|
||||||
|
CUTLASS_TRACE_HOST(
|
||||||
|
" cudaFuncSetAttribute() returned error: "
|
||||||
|
<< cudaGetErrorString(result));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmha_bwd_reference_dK_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion);
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template<
|
||||||
|
class ProblemShape,
|
||||||
|
class TensorQ, class TensorK, class TensorV,
|
||||||
|
class TensorO, class TensorLSE, class TensorDO,
|
||||||
|
/** / class TensorDQ, / ** / class TensorDK, / **/ class TensorDV, /**/
|
||||||
|
class Fusion
|
||||||
|
>
|
||||||
|
void fmha_bwd_reference_dV(
|
||||||
|
ProblemShape problem_shape,
|
||||||
|
TensorQ mQ, TensorK mK, TensorV mV,
|
||||||
|
TensorO mO, TensorLSE mLSE, TensorDO mDO,
|
||||||
|
/** / TensorDQ mDQ, / ** / TensorDK mDK, / **/ TensorDV mDV, /**/
|
||||||
|
Fusion fusion
|
||||||
|
) {
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
dim3 grid(size<0>(mDV), size<2>(mDV), 1);
|
||||||
|
dim3 block(256);
|
||||||
|
int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type);
|
||||||
|
|
||||||
|
if (shared_mem >= (48 << 10)) {
|
||||||
|
CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem);
|
||||||
|
auto result = cudaFuncSetAttribute(
|
||||||
|
fmha_bwd_reference_dV_kernel<ProblemShape, TensorQ, TensorK, TensorV, TensorO, TensorLSE, TensorDO, TensorDV, Fusion>,
|
||||||
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||||
|
shared_mem);
|
||||||
|
if (cudaSuccess != result) {
|
||||||
|
result = cudaGetLastError(); // to clear the error bit
|
||||||
|
CUTLASS_TRACE_HOST(
|
||||||
|
" cudaFuncSetAttribute() returned error: "
|
||||||
|
<< cudaGetErrorString(result));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmha_bwd_reference_dV_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion);
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template<
|
||||||
|
class ProblemShape,
|
||||||
|
class TensorQ, class TensorK, class TensorV,
|
||||||
|
class TensorO, class TensorLSE, class TensorDO,
|
||||||
|
class TensorDQ, class TensorDK, class TensorDV,
|
||||||
|
class Fusion
|
||||||
|
>
|
||||||
|
void fmha_bwd_reference(
|
||||||
|
ProblemShape problem_shape,
|
||||||
|
TensorQ mQ, TensorK mK, TensorV mV,
|
||||||
|
TensorO mO, TensorLSE mLSE, TensorDO mDO,
|
||||||
|
TensorDQ mDQ, TensorDK mDK, TensorDV mDV,
|
||||||
|
Fusion fusion
|
||||||
|
) {
|
||||||
|
fmha_bwd_reference_dQ(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion);
|
||||||
|
fmha_bwd_reference_dK(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion);
|
||||||
|
fmha_bwd_reference_dV(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion);
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
156 examples/88_hopper_fmha/reference/fmha_reference.hpp Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "cute/tensor.hpp"
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template<
|
||||||
|
class ProblemShape,
|
||||||
|
class TensorQ,
|
||||||
|
class TensorK,
|
||||||
|
class TensorV,
|
||||||
|
class TensorO,
|
||||||
|
class TensorLSE,
|
||||||
|
class Fusion
|
||||||
|
>
|
||||||
|
void __global__ fmha_reference_kernel(
|
||||||
|
ProblemShape problem_shape,
|
||||||
|
TensorQ mQ, TensorK mK, TensorV mV,
|
||||||
|
TensorO mO, TensorLSE mLSE,
|
||||||
|
Fusion fusion
|
||||||
|
) {
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
using Element = typename TensorO::value_type;
|
||||||
|
using ElementAccumulator = typename TensorLSE::value_type;
|
||||||
|
|
||||||
|
extern __shared__ char mS_mem[];
|
||||||
|
Element* mS = reinterpret_cast<Element*>(mS_mem);
|
||||||
|
|
||||||
|
ElementAccumulator softmax_scale = static_cast<ElementAccumulator>(1.0 / sqrt(1.0 * size<1>(mO)));
|
||||||
|
|
||||||
|
auto id = make_identity_tensor(make_shape(1, 1));
|
||||||
|
for (int idx_L = blockIdx.y; idx_L < size<2>(mO); idx_L += gridDim.y) {
|
||||||
|
for (int idx_Q = blockIdx.x; idx_Q < size<0>(mO); idx_Q += gridDim.x) {
|
||||||
|
for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) {
|
||||||
|
ElementAccumulator acc = 0;
|
||||||
|
for (int idx_D = 0; idx_D < size<1>(mK); idx_D++) {
|
||||||
|
acc += mQ(idx_Q, idx_D, idx_L) * mK(idx_K, idx_D, idx_L);
|
||||||
|
}
|
||||||
|
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
|
||||||
|
frag(0) = acc;
|
||||||
|
fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape);
|
||||||
|
mS[idx_K] = static_cast<Element>(frag(0) * softmax_scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
ElementAccumulator maxS = -std::numeric_limits<ElementAccumulator>::infinity();
|
||||||
|
for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
|
||||||
|
maxS = std::max<ElementAccumulator>(maxS, mS[idx_K]);
|
||||||
|
}
|
||||||
|
if (maxS == -std::numeric_limits<ElementAccumulator>::infinity()) maxS = 0;
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) {
|
||||||
|
mS[idx_K] = static_cast<Element>(exp(mS[idx_K] - maxS));
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
ElementAccumulator sum = 0;
|
||||||
|
for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
|
||||||
|
sum += mS[idx_K];
|
||||||
|
}
|
||||||
|
|
||||||
|
Element scale = static_cast<Element>(1.0 / sum);
|
||||||
|
|
||||||
|
for (int idx_D = threadIdx.x; idx_D < size<1>(mO); idx_D += blockDim.x) {
|
||||||
|
ElementAccumulator acc = 0;
|
||||||
|
for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) {
|
||||||
|
acc += mS[idx_K] * mV(idx_K, idx_D, idx_L) * scale;
|
||||||
|
}
|
||||||
|
mO(idx_Q, idx_D, idx_L) = static_cast<Element>(acc);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
mLSE(idx_Q, idx_L) = log(sum) + maxS;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
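// Added reading aid, not part of the original source. The forward reference above performs a
// numerically stable softmax: with $m_i = \max_j s\,S_{ij}$ it stores
//   $\mathrm{LSE}_i = m_i + \log \sum_j e^{s\,S_{ij} - m_i}$  and  $O = \mathrm{softmax}(s\,S)\,V$,
// which is exactly the log-sum-exp the backward reference consumes via $P_{ij} = \exp(s\,S_{ij} - \mathrm{LSE}_i)$.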
|
||||||
|
|
||||||
|
template<
|
||||||
|
class ProblemShape,
|
||||||
|
class TensorQ,
|
||||||
|
class TensorK,
|
||||||
|
class TensorV,
|
||||||
|
class TensorO,
|
||||||
|
class TensorLSE,
|
||||||
|
class Fusion
|
||||||
|
>
|
||||||
|
void fmha_reference(
|
||||||
|
ProblemShape problem_shape,
|
||||||
|
TensorQ mQ, TensorK mK, TensorV mV,
|
||||||
|
TensorO mO, TensorLSE mLSE,
|
||||||
|
Fusion fusion
|
||||||
|
) {
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
dim3 grid(size<0>(mO), size<2>(mO), 1);
|
||||||
|
dim3 block(256);
|
||||||
|
int shared_mem = size<0>(mK) * sizeof(typename TensorO::value_type);
|
||||||
|
|
||||||
|
if (shared_mem >= (48 << 10)) {
|
||||||
|
CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem);
|
||||||
|
auto result = cudaFuncSetAttribute(
|
||||||
|
fmha_reference_kernel<ProblemShape, TensorQ, TensorK, TensorV, TensorO, TensorLSE, Fusion>,
|
||||||
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||||
|
shared_mem);
|
||||||
|
if (cudaSuccess != result) {
|
||||||
|
result = cudaGetLastError(); // to clear the error bit
|
||||||
|
CUTLASS_TRACE_HOST(
|
||||||
|
" cudaFuncSetAttribute() returned error: "
|
||||||
|
<< cudaGetErrorString(result));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmha_reference_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, fusion);
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
129 examples/88_hopper_fmha/reference/reference_abs_error.hpp Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include "cutlass/util/device_memory.h"
|
||||||
|
|
||||||
|
template<typename Element>
|
||||||
|
__global__ void reference_abs_diff_kernel(
|
||||||
|
Element* data, Element* data_ref, size_t count,
|
||||||
|
double* max_diff, double* sum_diff,
|
||||||
|
bool print_diff
|
||||||
|
) {
|
||||||
|
double thread_max_diff = 0;
|
||||||
|
double thread_sum_diff = 0;
|
||||||
|
|
||||||
|
__shared__ double block_max_diff;
|
||||||
|
__shared__ double block_sum_diff;
|
||||||
|
|
||||||
|
for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) {
|
||||||
|
double diff = fabs(data[i] - data_ref[i]);
|
||||||
|
if (print_diff) if (diff != diff || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast<long long int>(i), diff, (double)data[i], (double)data_ref[i]);
|
||||||
|
thread_max_diff = fmax(diff, thread_max_diff);
|
||||||
|
thread_sum_diff += diff;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < blockDim.x; i++) {
|
||||||
|
if (i == threadIdx.x) {
|
||||||
|
if (i == 0) {
|
||||||
|
block_max_diff = thread_max_diff;
|
||||||
|
block_sum_diff = thread_sum_diff;
|
||||||
|
} else {
|
||||||
|
block_max_diff = fmax(block_max_diff, thread_max_diff);
|
||||||
|
block_sum_diff += thread_sum_diff;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
atomicAdd(sum_diff, block_sum_diff);
|
||||||
|
|
||||||
|
for (;;) {
|
||||||
|
unsigned long long prev = *reinterpret_cast<unsigned long long*>(max_diff);
|
||||||
|
double prev_diff = reinterpret_cast<double const&>(prev);
|
||||||
|
double new_max_diff = fmax(block_max_diff, prev_diff);
|
||||||
|
unsigned long long found = atomicCAS(reinterpret_cast<unsigned long long*>(max_diff), prev, reinterpret_cast<unsigned long long const&>(new_max_diff));
|
||||||
|
if (found == prev) break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename Element>
|
||||||
|
void reference_abs_diff(
|
||||||
|
cutlass::DeviceAllocation<Element> const& data,
|
||||||
|
cutlass::DeviceAllocation<Element> const& data_ref,
|
||||||
|
double& max_diff, double& mean_diff
|
||||||
|
) {
|
||||||
|
static bool kPrintDiff = getenv("REF_PRINT_DIFF") && atoi(getenv("REF_PRINT_DIFF")) == 1;
|
||||||
|
|
||||||
|
cutlass::DeviceAllocation<double> result;
|
||||||
|
result.reset(2);
|
||||||
|
assert(data.size() == data_ref.size());
|
||||||
|
|
||||||
|
cudaError_t err = cudaMemset(result.get(), 0, result.size() * sizeof(double));
|
||||||
|
if (err != cudaSuccess) {
|
||||||
|
std::cerr << "Memset failed. Last CUDA error: "
|
||||||
|
<< cudaGetErrorString(err) << std::endl;
|
||||||
|
max_diff = mean_diff = 1e20;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dim3 block(256, 1, 1);
|
||||||
|
dim3 grid(1024, 1, 1);
|
||||||
|
reference_abs_diff_kernel<<<grid, block>>>(
|
||||||
|
data.get(), data_ref.get(), data.size(),
|
||||||
|
result.get(), result.get() + 1, kPrintDiff);
|
||||||
|
|
||||||
|
err = cudaDeviceSynchronize();
|
||||||
|
if (err != cudaSuccess) {
|
||||||
|
std::cerr << "Difference kernel failed. Last CUDA error: "
|
||||||
|
<< cudaGetErrorString(err) << std::endl;
|
||||||
|
max_diff = mean_diff = 1e20;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
double result_host[2];
|
||||||
|
err = cudaMemcpy(result_host, result.get(), result.size() * sizeof(double), cudaMemcpyDefault);
|
||||||
|
if (err != cudaSuccess) {
|
||||||
|
std::cerr << "Copy failed. Last CUDA error: "
|
||||||
|
<< cudaGetErrorString(err) << std::endl;
|
||||||
|
max_diff = mean_diff = 1e20;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
max_diff = result_host[0];
|
||||||
|
mean_diff = result_host[1] / static_cast<double>(data.size());
|
||||||
|
}
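// Added usage sketch, not part of the original header: the buffer names and tolerances below
// are illustrative assumptions, only reference_abs_diff() itself comes from this file.
//   double max_diff = 0.0, mean_diff = 0.0;
//   reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff);  // both cutlass::DeviceAllocation<Element>
//   bool passed = (max_diff < 1e-2) && (mean_diff < 1e-3);          // example tolerances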
|
||||||
@@ -163,6 +163,7 @@ foreach(EXAMPLE
  82_blackwell_distributed_gemm
  83_blackwell_sparse_gemm
  84_blackwell_narrow_precision_sparse_gemm
  88_hopper_fmha
)

add_subdirectory(${EXAMPLE})

@@ -55,3 +55,7 @@ cutlass_example_add_executable(
  tiled_copy.cu
)

cutlass_example_add_executable(
  cute_tutorial_tiled_copy_if
  tiled_copy_if.cu
)

297 examples/cute/tutorial/tiled_copy_if.cu Normal file
@@ -0,0 +1,297 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
|
||||||
|
#include <thrust/host_vector.h>
|
||||||
|
#include <thrust/device_vector.h>
|
||||||
|
|
||||||
|
#include <cute/tensor.hpp>
|
||||||
|
|
||||||
|
#include "cutlass/util/print_error.hpp"
|
||||||
|
#include "cutlass/util/GPU_Clock.hpp"
|
||||||
|
#include "cutlass/util/helper_cuda.hpp"
|
||||||
|
|
||||||
|
// This example extends `tiled_copy` using predicate tensors to guard memory accesses performed
|
||||||
|
// by `cute::copy_if()`. This enables tensors to have shapes that are not integer multiples of
|
||||||
|
// block sizes.
|
||||||
|
//
|
||||||
|
// This is accomplished by instantiating a tensor of coordinates which correspond to tensor elements
|
||||||
|
// to be accessed and then computing a predicate tensor which masks accesses. The example demonstrates
|
||||||
|
// how the construction of an identity tensor containing coordinates and a predicate tensor containing
|
||||||
|
// mask bits can be implemented using the same CuTe operations used to tile the tensors in
|
||||||
|
// Global Memory.
|
||||||
|
//
|
||||||
|
// This example implements two variants:
|
||||||
|
// - copy_if_kernel() uses `cute::local_partition()` to construct each thread's slice
|
||||||
|
// - copy_if_kernel_vectorized() uses `make_tiled_copy()` to implement vectorized memory accesses.
|
||||||
|
//
|
||||||
|
// The tensor shapes and strides must be divisible by the shape of the vector access.
|
||||||
|
//
|
||||||
|
|
||||||
|
/// Simple copy kernel.
|
||||||
|
//
|
||||||
|
// Uses local_partition() to partition a tile among threads arranged as (THR_M, THR_N).
|
||||||
|
template <class TensorS, class TensorD, class BlockShape, class ThreadLayout>
|
||||||
|
__global__ void copy_if_kernel(TensorS S, TensorD D, BlockShape block_shape, ThreadLayout)
|
||||||
|
{
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
// Construct a coordinate tensor whose elements are the coordinates used to access tensors S and D.
|
||||||
|
auto shape_S = shape(S);
|
||||||
|
Tensor C = make_identity_tensor(shape_S);
|
||||||
|
// Construct a predicate tensor which compares the coordinates with the original shape
|
||||||
|
Tensor P = cute::lazy::transform(C, [&](auto c) { return elem_less(c, shape_S); });
|
||||||
|
|
||||||
|
// Tile the input tensor into blocks
|
||||||
|
auto block_coord = make_coord(blockIdx.x, blockIdx.y);
|
||||||
|
Tensor tile_S = local_tile(S, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
|
||||||
|
Tensor tile_D = local_tile(D, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
|
||||||
|
Tensor tile_P = local_tile(P, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
|
||||||
|
|
||||||
|
// Construct a partitioning of the tile among threads with the given thread arrangement.
|
||||||
|
|
||||||
|
// Concept: Tensor ThrLayout ThrIndex
|
||||||
|
Tensor thr_tile_S = local_partition(tile_S, ThreadLayout{}, threadIdx.x);
|
||||||
|
Tensor thr_tile_D = local_partition(tile_D, ThreadLayout{}, threadIdx.x);
|
||||||
|
Tensor thr_tile_P = local_partition(tile_P, ThreadLayout{}, threadIdx.x);
|
||||||
|
|
||||||
|
// Copy from GMEM to GMEM using `thr_tile_P` to guard accesses.
|
||||||
|
copy_if(thr_tile_P, thr_tile_S, thr_tile_D);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Vectorized copy kernel.
|
||||||
|
///
|
||||||
|
/// Uses `make_tiled_copy()` to perform a copy using vector instructions. This operation
|
||||||
|
/// has the precondition that pointers are aligned to the vector size.
|
||||||
|
///
|
||||||
|
template <class TensorS, class TensorD, class BlockShape, class Tiled_Copy>
|
||||||
|
__global__ void copy_if_kernel_vectorized(TensorS S, TensorD D, BlockShape block_shape, Tiled_Copy tiled_copy)
|
||||||
|
{
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
// Construct a coordinate tensor whose elements are the coordinates used to access tensors S and D.
|
||||||
|
auto shape_S = shape(S);
|
||||||
|
Tensor C = make_identity_tensor(shape_S);
|
||||||
|
// Construct a predicate tensor which compares the coordinates with the original shape
|
||||||
|
Tensor P = cute::lazy::transform(C, [&](auto c) { return elem_less(c, shape_S); });
|
||||||
|
|
||||||
|
// Tile the input tensor into blocks
|
||||||
|
auto block_coord = make_coord(blockIdx.x, blockIdx.y);
|
||||||
|
Tensor tile_S = local_tile(S, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
|
||||||
|
Tensor tile_D = local_tile(D, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
|
||||||
|
Tensor tile_P = local_tile(P, block_shape, block_coord); // (BlockShape_M, BlockShape_N)
|
||||||
|
|
||||||
|
//
|
||||||
|
// Construct a Tensor corresponding to each thread's slice.
|
||||||
|
//
|
||||||
|
ThrCopy thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
|
||||||
|
Tensor thr_tile_S = thr_copy.partition_S(tile_S); // (CPY, CPY_M, CPY_N)
|
||||||
|
Tensor thr_tile_D = thr_copy.partition_D(tile_D); // (CPY, CPY_M, CPY_N)
|
||||||
|
Tensor thr_tile_P = thr_copy.partition_S(tile_P); // (CPY, CPY_M, CPY_N)
|
||||||
|
|
||||||
|
#if 0
|
||||||
|
// Copy from GMEM to GMEM
|
||||||
|
copy_if(tiled_copy, thr_tile_P, thr_tile_S, thr_tile_D);
|
||||||
|
#else
|
||||||
|
// make_fragment_like() constructs a tensor in RMEM with the same shape as thr_tile_S.
|
||||||
|
Tensor frag = make_fragment_like(thr_tile_S);
|
||||||
|
|
||||||
|
// Copy from GMEM to RMEM and from RMEM to GMEM
|
||||||
|
copy_if(tiled_copy, thr_tile_P, thr_tile_S, frag);
|
||||||
|
copy_if(tiled_copy, thr_tile_P, frag, thr_tile_D);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Main function
|
||||||
|
int main(int argc, char** argv)
|
||||||
|
{
|
||||||
|
//
|
||||||
|
// Given a 2D shape, perform an efficient copy
|
||||||
|
//
|
||||||
|
|
||||||
|
using namespace cute;
|
||||||
|
using Element = float;
|
||||||
|
|
||||||
|
// Define a tensor shape with dynamic extents (m, n)
|
||||||
|
auto tensor_shape = make_shape(528, 300);
|
||||||
|
|
||||||
|
thrust::host_vector<Element> h_S(size(tensor_shape));
|
||||||
|
thrust::host_vector<Element> h_D(size(tensor_shape));
|
||||||
|
|
||||||
|
//
|
||||||
|
// Initialize
|
||||||
|
//
|
||||||
|
|
||||||
|
for (size_t i = 0; i < h_S.size(); ++i) {
|
||||||
|
h_S[i] = static_cast<Element>(i);
|
||||||
|
h_D[i] = Element{};
|
||||||
|
}
|
||||||
|
|
||||||
|
thrust::device_vector<Element> d_S = h_S;
|
||||||
|
thrust::device_vector<Element> d_D = h_D;
|
||||||
|
thrust::device_vector<Element> d_Zero = h_D;
|
||||||
|
|
||||||
|
//
|
||||||
|
// Make tensors
|
||||||
|
//
|
||||||
|
|
||||||
|
Tensor tensor_S = make_tensor(make_gmem_ptr(d_S.data().get()), make_layout(tensor_shape));
|
||||||
|
Tensor tensor_D = make_tensor(make_gmem_ptr(d_D.data().get()), make_layout(tensor_shape));
|
||||||
|
|
||||||
|
//
|
||||||
|
// Partition
|
||||||
|
//
|
||||||
|
|
||||||
|
// Define a statically sized block (M, N).
|
||||||
|
//
|
||||||
|
// Note, by convention, capital letters are used to represent static modes.
|
||||||
|
auto block_shape = make_shape(Int<128>{}, Int<64>{});
|
||||||
|
|
||||||
|
// Tile the tensor (m, n) ==> ((M, N), m', n') where (M, N) is the static tile
|
||||||
|
// shape, and modes (m', n') correspond to the number of tiles.
|
||||||
|
//
|
||||||
|
// These will be used to determine the CUDA kernel grid dimensions.
|
||||||
|
Tensor tiled_tensor_D = tiled_divide(tensor_D, block_shape); // ((M, N), m', n')
|
||||||
|
|
||||||
|
// Describes the layout of threads which is then replicated to tile 'block_shape.'
|
||||||
|
Layout thr_layout = make_layout(make_shape(Int<32>{}, Int< 8>{})); // (ThrM, ThrN)
|
||||||
|
|
||||||
|
//
|
||||||
|
// Determine grid and block dimensions
|
||||||
|
//
|
||||||
|
|
||||||
|
dim3 gridDim (size<1>(tiled_tensor_D), size<2>(tiled_tensor_D)); // Grid shape corresponds to modes m' and n'
|
||||||
|
dim3 blockDim(size(thr_layout));
|
||||||
|
|
||||||
|
//
|
||||||
|
// Launch the kernel
|
||||||
|
//
|
||||||
|
|
||||||
|
// copy_if()
|
||||||
|
copy_if_kernel<<< gridDim, blockDim >>>(
|
||||||
|
tensor_S,
|
||||||
|
tensor_D,
|
||||||
|
block_shape,
|
||||||
|
thr_layout);
|
||||||
|
|
||||||
|
cudaError result = cudaDeviceSynchronize();
|
||||||
|
|
||||||
|
if (result != cudaSuccess) {
|
||||||
|
std::cerr << "CUDA Runtime error: " << cudaGetErrorString(result) << std::endl;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
h_D = d_D;
|
||||||
|
|
||||||
|
//
|
||||||
|
// Verification
|
||||||
|
//
|
||||||
|
|
||||||
|
auto verify = [](thrust::host_vector<Element> const &S, thrust::host_vector<Element> const &D){
|
||||||
|
|
||||||
|
int32_t errors = 0;
|
||||||
|
int32_t const kErrorLimit = 10;
|
||||||
|
|
||||||
|
if (S.size() != D.size()) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < D.size(); ++i) {
|
||||||
|
if (S[i] != D[i]) {
|
||||||
|
std::cerr << "Error. S[" << i << "]: " << S[i] << ", D[" << i << "]: " << D[i] << std::endl;
|
||||||
|
|
||||||
|
if (++errors >= kErrorLimit) {
|
||||||
|
std::cerr << "Aborting on " << kErrorLimit << "nth error." << std::endl;
|
||||||
|
return errors;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors;
|
||||||
|
};
|
||||||
|
|
||||||
|
if (verify(h_D, h_S)) {
|
||||||
|
return -1;
|
||||||
|
} else {
|
||||||
|
std::cout << "Success." << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
thrust::copy(d_Zero.begin(), d_Zero.end(), d_D.begin());
|
||||||
|
|
||||||
|
// Construct a TiledCopy with a specific access pattern.
|
||||||
|
// This version uses a
|
||||||
|
// (1) Layout-of-Threads to describe the number and arrangement of threads (e.g. row-major, col-major, etc),
|
||||||
|
// (2) Layout-of-Values that each thread will access.
|
||||||
|
|
||||||
|
// Value arrangement per thread
|
||||||
|
Layout val_layout = make_layout(make_shape(Int<4>{}, Int<1>{})); // (4,1) -> val_idx
|
||||||
|
|
||||||
|
// Define `AccessType` which controls the size of the actual memory access instruction.
|
||||||
|
using CopyOp = UniversalCopy<uint_byte_t<sizeof(Element) * size(val_layout)>>; // A very specific access width copy instruction
|
||||||
|
//using CopyOp = UniversalCopy<cutlass::AlignedArray<Element, size(val_layout)>>; // A more generic type that supports many copy strategies
|
||||||
|
//using CopyOp = AutoVectorizingCopy; // An adaptable-width instruction that assumes maximal alignment of inputs
|
||||||
|
|
||||||
|
// A Copy_Atom corresponds to one CopyOperation applied to Tensors of type Element.
|
||||||
|
using Atom = Copy_Atom<CopyOp, Element>;
|
||||||
|
|
||||||
|
// Construct tiled copy, a tiling of copy atoms.
|
||||||
|
//
|
||||||
|
// Note, this assumes the vector and thread layouts are aligned with contiguous data
|
||||||
|
// in GMEM. Alternative thread layouts are possible but may result in uncoalesced
|
||||||
|
// reads. Alternative value layouts are also possible, though incompatible layouts
|
||||||
|
// will result in compile time errors.
|
||||||
|
TiledCopy tiled_copy = make_tiled_copy(Atom{}, // Access strategy
|
||||||
|
thr_layout, // thread layout (e.g. 32x8 Col-Major)
|
||||||
|
val_layout); // value layout (e.g. 4x1)
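// Added sizing note (not from the original tutorial): with thr_layout = (32, 8) threads
// and val_layout = (4, 1) values per thread, one pass of this TiledCopy covers a
// (32*4) x (8*1) = (128, 8) element tile, so the (128, 64) block is traversed in
// 64 / 8 = 8 passes and each thread issues one 4 x float = 16-byte access per pass.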
|
||||||
|
|
||||||
|
// copy_if() with vectorization
|
||||||
|
copy_if_kernel_vectorized<<< gridDim, blockDim >>>(
|
||||||
|
tensor_S,
|
||||||
|
tensor_D,
|
||||||
|
block_shape,
|
||||||
|
tiled_copy);
|
||||||
|
|
||||||
|
result = cudaDeviceSynchronize();
|
||||||
|
|
||||||
|
if (result != cudaSuccess) {
|
||||||
|
std::cerr << "CUDA Runtime error: " << cudaGetErrorString(result) << std::endl;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
h_D = d_D;
|
||||||
|
|
||||||
|
if (verify(h_D, h_S)) {
|
||||||
|
return -1;
|
||||||
|
} else {
|
||||||
|
std::cout << "Success." << std::endl;
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
200 examples/python/CuTeDSL/ampere/smem_allocator.py Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
# Redistribution and use in source and binary forms, with or without
|
||||||
|
# modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
# list of conditions and the following disclaimer.
|
||||||
|
|
||||||
|
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
# this list of conditions and the following disclaimer in the documentation
|
||||||
|
# and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
# 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
# contributors may be used to endorse or promote products derived from
|
||||||
|
# this software without specific prior written permission.
|
||||||
|
|
||||||
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
import cutlass.cute as cute
|
||||||
|
import cutlass
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from cutlass.cute.runtime import from_dlpack
|
||||||
|
|
||||||
|
"""
|
||||||
|
A Shared Memory Allocator Example on NVIDIA Ampere architecture using CuTe DSL.
|
||||||
|
|
||||||
|
This example demonstrates how to allocate and manage shared memory in JIT kernels by using the SmemAllocator in CuTe DSL.
|
||||||
|
It shows various ways to allocate different data structures in shared memory:
|
||||||
|
|
||||||
|
1. Struct allocation with natural and strict alignment
|
||||||
|
2. Raw memory block allocation with custom alignment
|
||||||
|
3. Array allocation with automatic alignment
|
||||||
|
4. Tensor allocation with layout specification
|
||||||
|
|
||||||
|
The example includes:
|
||||||
|
- Shared storage struct with mixed alignment requirements
|
||||||
|
- Memory allocation patterns for different data types
|
||||||
|
- Tensor operations on allocated memory
|
||||||
|
|
||||||
|
To run this example:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
python examples/ampere/smem_allocator.py
|
||||||
|
|
||||||
|
The example will allocate shared memory, perform tensor operations, and verify the results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@cute.struct
|
||||||
|
class complex:
|
||||||
|
real: cutlass.Float32
|
||||||
|
imag: cutlass.Float32
|
||||||
|
|
||||||
|
|
||||||
|
# SharedStorage size is 512, alignment is 128
|
||||||
|
@cute.struct
|
||||||
|
class SharedStorage:
|
||||||
|
# struct elements with natural alignment
|
||||||
|
a: cute.struct.MemRange[cutlass.Float32, 32] # array
|
||||||
|
b: cutlass.Int64 # scalar
|
||||||
|
c: complex # nested struct
|
||||||
|
# struct elements with strict alignment
|
||||||
|
x: cute.struct.Align[
|
||||||
|
cute.struct.MemRange[cutlass.Float32, 32],
|
||||||
|
128,
|
||||||
|
]
|
||||||
|
y: cute.struct.Align[cutlass.Int32, 8]
|
||||||
|
z: cute.struct.Align[complex, 16]
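# Added size accounting for the 512-byte figure above (an illustrative note; offsets assume
# C-style struct layout rules, so treat them as a sketch rather than authoritative):
#   a: 32 * 4 B = 128 B at offset 0, b: 8 B at offset 128, c: 8 B at offset 136,
#   x: padded up to offset 256 for its 128-B alignment, 128 B long,
#   y: 4 B at offset 384, z: padded to offset 400, 8 B long,
#   then the total (408 B) is rounded up to the 128-B struct alignment -> 512 B.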
|
||||||
|
|
||||||
|
|
||||||
|
@cute.kernel
|
||||||
|
def kernel(
|
||||||
|
const_a: cutlass.Constexpr,
|
||||||
|
dst_a: cute.Tensor,
|
||||||
|
const_b: cutlass.Constexpr,
|
||||||
|
dst_b: cute.Tensor,
|
||||||
|
const_c: cutlass.Constexpr,
|
||||||
|
dst_c: cute.Tensor,
|
||||||
|
):
|
||||||
|
# Note: SMEM_SIZE bytes (specified in kernel().launch(smem=...)) are reserved for the developer to use
# Note: alignment of the initial allocator base ptr is 1024
|
||||||
|
allocator = cutlass.utils.SmemAllocator()
|
||||||
|
# base ptr of allocator points at: SMEM_ADDR_START (the starting address of available shared memory)
|
||||||
|
|
||||||
|
# -- Allocate a struct --
|
||||||
|
# Note: when an alignment is specified, max(alignment, alignof(struct)) is applied
|
||||||
|
# reserves the struct's storage in smem; its elements can be accessed through the returned ptr
|
||||||
|
struct_in_smem = allocator.allocate(SharedStorage)
|
||||||
|
# base ptr of allocator now points at: SMEM_ADDR_AFTER_STRUCT = SMEM_ADDR_START + aligned_size(struct)
|
||||||
|
|
||||||
|
# -- Allocate a block of memory --
|
||||||
|
# reserves a section of 64 bytes in smem, align to 128 bytes, returns the section base ptr
|
||||||
|
section_in_smem = allocator.allocate(64, byte_alignment=128)
|
||||||
|
# base ptr of allocator now points at: SMEM_ADDR_AFTER_SECTION = SMEM_ADDR_AFTER_STRUCT + aligned_size(section)
|
||||||
|
|
||||||
|
# -- Allocate an array --
|
||||||
|
# reserves an int64 array of size 14 in smem, returns the array base ptr
|
||||||
|
array_in_smem = allocator.allocate_array(element_type=cutlass.Int64, num_elems=14)
|
||||||
|
# base ptr of allocator now points at: SMEM_ADDR_AFTER_ARRAY = SMEM_ADDR_AFTER_SECTION + aligned_size(array)
|
||||||
|
|
||||||
|
# -- Allocate a tensor --
|
||||||
|
# Note: use cute.ComposedLayout or cute.Layout to specify layout of tensor
|
||||||
|
# Note: iterator swizzle with swizzle layout is currently not supported
|
||||||
|
layout = cute.make_layout((16, 2))
|
||||||
|
tensor_in_smem = allocator.allocate_tensor(
|
||||||
|
element_type=cutlass.Float32, layout=layout, byte_alignment=32, swizzle=None
|
||||||
|
)
|
||||||
|
# base ptr of allocator now points at: SMEM_ADDR_AFTER_TENSOR = SMEM_ADDR_AFTER_ARRAY + aligned_size(tensor)

    # ptr<f32, smem, align<1024>>
    # ptr<i64, smem, align<128>>
    # ptr<f32, smem, align<8>>
    print(struct_in_smem.a.data_ptr())
    print(struct_in_smem.b)
    print(struct_in_smem.c.real)
    # ptr<i8, smem, align<512>>
    print(section_in_smem)
    # ptr<i64, smem, align<64>>
    print(array_in_smem)
    # tensor<ptr<f32, smem, align<32>> o (16,2):(1,16)>
    print(tensor_in_smem)

    # fill the MemRange tensor in the struct and copy it to dst
    a_tensor = struct_in_smem.a.get_tensor(cute.make_layout((8, 4)))
    a_tensor.fill(const_a)
    cute.printf("cute.struct.MemRange: {}", a_tensor)
    dst_a.store(a_tensor.load())

    # wrap the raw block of smem in a tensor, fill it, and copy it to dst
    layout = cute.make_layout((8, 2))
    sec_ptr = cute.recast_ptr(section_in_smem, dtype=cutlass.Float32)
    sec_tensor = cute.make_tensor(sec_ptr, layout)
    sec_tensor.fill(const_b)
    cute.printf("block of memory: {}", sec_tensor)
    dst_b.store(sec_tensor.load())

    # fill the allocated tensor in smem and copy it to dst
    tensor_in_smem.fill(const_c)
    cute.printf("tensor in smem: {}", tensor_in_smem)
    dst_c.store(tensor_in_smem.load())


@cute.jit
def run_allocation_kernel(
    const_a: cutlass.Constexpr,
    dst_a: cute.Tensor,
    const_b: cutlass.Constexpr,
    dst_b: cute.Tensor,
    const_c: cutlass.Constexpr,
    dst_c: cute.Tensor,
):
    # additional size for the example: 64 (section) + 112 (array) + 128 (tensor) < 384
    additional_bytes = 384
    # Note: the launch shared memory size is SMEM_SIZE = 512 + 384 = 896 bytes
    kernel(const_a, dst_a, const_b, dst_b, const_c, dst_c).launch(
        grid=(1, 1, 1),
        block=(1, 1, 1),
        smem=SharedStorage.size_in_bytes() + additional_bytes,
    )


def verify_allocation_kernel(const_a, const_b, const_c):
    dst_a = torch.zeros((8, 4), dtype=torch.float32, device="cuda")
    dst_b = torch.zeros((8, 2), dtype=torch.float32, device="cuda")
    dst_c = torch.zeros((16, 2), dtype=torch.float32, device="cuda")

    run_allocation_kernel(
        const_a,
        from_dlpack(dst_a),
        const_b,
        from_dlpack(dst_b),
        const_c,
        from_dlpack(dst_c),
    )

    np.testing.assert_equal(const_a, dst_a.detach().cpu().numpy()[0])
    np.testing.assert_equal(const_b, dst_b.detach().cpu().numpy()[0])
    np.testing.assert_equal(const_c, dst_c.detach().cpu().numpy()[0])


if __name__ == "__main__":
    # prepare the cuda context
    cutlass.cuda.initialize_cuda_context()
    # an example of shared memory allocation
    const_a = 0.5
    const_b = 1.0
    const_c = 2.0
    verify_allocation_kernel(const_a, const_b, const_c)
51
examples/python/CuTeDSL/cute/ffi/CMakeLists.txt
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
# Redistribution and use in source and binary forms, with or without
|
||||||
|
# modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
# list of conditions and the following disclaimer.
|
||||||
|
|
||||||
|
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
# this list of conditions and the following disclaimer in the documentation
|
||||||
|
# and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
# 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
# contributors may be used to endorse or promote products derived from
|
||||||
|
# this software without specific prior written permission.
|
||||||
|
|
||||||
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
cmake_minimum_required(VERSION 3.15)
|
||||||
|
project(tensor)
|
||||||
|
|
||||||
|
# Find Python
|
||||||
|
find_package(Python COMPONENTS Interpreter Development REQUIRED)
|
||||||
|
|
||||||
|
# Get Python site-packages directory using Python
|
||||||
|
execute_process(
|
||||||
|
COMMAND ${Python_EXECUTABLE} -c "import site; print(site.getsitepackages()[0])"
|
||||||
|
OUTPUT_VARIABLE Python_SITE_PACKAGES
|
||||||
|
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||||
|
)
|
||||||
|
|
||||||
|
message(STATUS "Python site-packages directory: ${Python_SITE_PACKAGES}")
|
||||||
|
|
||||||
|
# Add nanobind path to CMAKE_PREFIX_PATH
|
||||||
|
list(APPEND CMAKE_PREFIX_PATH ${Python_SITE_PACKAGES}/nanobind/cmake)
|
||||||
|
|
||||||
|
# Find nanobind
|
||||||
|
find_package(nanobind REQUIRED)
|
||||||
|
|
||||||
|
# Add the module
|
||||||
|
nanobind_add_module(tensor tensor.cpp)
|
||||||
305
examples/python/CuTeDSL/cute/ffi/jit_argument.py
Normal file
@ -0,0 +1,305 @@
|
|||||||
|
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
# Redistribution and use in source and binary forms, with or without
|
||||||
|
# modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
# list of conditions and the following disclaimer.
|
||||||
|
|
||||||
|
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
# this list of conditions and the following disclaimer in the documentation
|
||||||
|
# and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
# 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
# contributors may be used to endorse or promote products derived from
|
||||||
|
# this software without specific prior written permission.
|
||||||
|
|
||||||
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
"""Example of accessing POD (Plain Old Data) from C or other languages via LLVM operations.
|
||||||
|
|
||||||
|
This example demonstrates a basic approach to building custom C-structure interfaces between user code
and JIT-compiled functions. It provides a minimal-overhead way to call JIT functions
and can be used to build AOT (Ahead-of-Time) launchers for JIT-compiled functions.
|
||||||
|
|
||||||
|
The C-structure is defined as:
|
||||||
|
|
||||||
|
.. code-block:: c
|
||||||
|
|
||||||
|
struct Tensor {
|
||||||
|
void *ptr; // Pointer to tensor data
|
||||||
|
int32_t shape[3]; // Tensor dimensions
|
||||||
|
int32_t strides[3]; // Memory strides for each dimension
|
||||||
|
};
|
||||||
|
|
||||||
|
The example defines Tensor and TensorValue classes that wrap a C struct view of a tensor (its data pointer,
shape, and strides), enabling efficient data passing across language boundaries.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
Future development may include automated code generation flows.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import cutlass
|
||||||
|
import cutlass.cute as cute
|
||||||
|
|
||||||
|
from cutlass._mlir import ir
|
||||||
|
from cutlass._mlir.dialects import llvm
|
||||||
|
import cutlass._mlir.extras.types as T
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleTensorValue(ir.Value):
|
||||||
|
"""A wrapper class for tensor values in MLIR.
|
||||||
|
|
||||||
|
This class extends ir.Value to provide convenient access to tensor data pointer,
|
||||||
|
shape, and strides through MLIR operations.
|
||||||
|
|
||||||
|
:type: ir.Value
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, v):
|
||||||
|
"""Initialize a new TensorValue.
|
||||||
|
|
||||||
|
:param v: The underlying MLIR value to wrap
|
||||||
|
:type v: ir.Value
|
||||||
|
"""
|
||||||
|
super().__init__(v)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data_ptr(self, *, loc=None, ip=None):
|
||||||
|
"""Get the data pointer from the tensor value.
|
||||||
|
|
||||||
|
Extracts the data pointer (first field) from the LLVM struct value.
|
||||||
|
|
||||||
|
:param loc: Optional location information for MLIR operations
|
||||||
|
:type loc: Optional[ir.Location]
|
||||||
|
:param ip: Optional insertion point for MLIR operations
|
||||||
|
:type ip: Optional[ir.InsertionPoint]
|
||||||
|
:return: An integer value representing the data pointer
|
||||||
|
:rtype: ir.Value
|
||||||
|
"""
|
||||||
|
# Extract the data pointer from the LLVM struct value
|
||||||
|
# The data pointer is the first field (index 0) in the struct
|
||||||
|
|
||||||
|
# Use llvm.extractvalue to get the pointer field from the struct
|
||||||
|
ptr_val = llvm.extractvalue(
|
||||||
|
llvm.PointerType.get(),
|
||||||
|
self,
|
||||||
|
[0], # Extract the first field (index 0)
|
||||||
|
loc=loc,
|
||||||
|
ip=ip,
|
||||||
|
)
|
||||||
|
|
||||||
|
return cute.make_ptr(cutlass.Float32, ptr_val)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self):
|
||||||
|
"""Get the shape of the tensor.
|
||||||
|
|
||||||
|
Extracts the shape (second field) from the LLVM struct value.
|
||||||
|
|
||||||
|
:return: A tuple of integers representing the tensor dimensions
|
||||||
|
:rtype: tuple[ir.Value, ...]
|
||||||
|
"""
|
||||||
|
i32_type = ir.IntegerType.get_signless(32)
|
||||||
|
# Extract the shape field from the LLVM struct value
|
||||||
|
# The shape is the second field (index 1) in the struct
|
||||||
|
shape_val = llvm.extractvalue(
|
||||||
|
llvm.StructType.get_literal([i32_type] * 3),
|
||||||
|
self,
|
||||||
|
[1], # Extract the second field (index 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract each dimension from the shape struct
|
||||||
|
return tuple(llvm.extractvalue(i32_type, shape_val, [i]) for i in range(3))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stride(self):
|
||||||
|
"""Get the strides of the tensor.
|
||||||
|
|
||||||
|
Extracts the strides (third field) from the LLVM struct value.
|
||||||
|
|
||||||
|
:return: A tuple of integers representing the tensor strides
|
||||||
|
:rtype: tuple[ir.Value, ...]
|
||||||
|
"""
|
||||||
|
i32_type = ir.IntegerType.get_signless(32)
|
||||||
|
# Extract the strides field from the LLVM struct value
|
||||||
|
# The strides are the third field (index 2) in the struct
|
||||||
|
strides_val = llvm.extractvalue(
|
||||||
|
llvm.StructType.get_literal([i32_type] * 3),
|
||||||
|
self,
|
||||||
|
[2], # Extract the third field (index 2)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract each dimension from the strides struct
|
||||||
|
return tuple(llvm.extractvalue(i32_type, strides_val, [i]) for i in range(3))
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleTensor:
|
||||||
|
"""A class representing a tensor with its data pointer, shape, and strides.
|
||||||
|
|
||||||
|
This class provides a Python interface to create and manipulate tensor structures
|
||||||
|
that can be passed to CUTE JIT compiled functions.
|
||||||
|
|
||||||
|
:ivar _c_struct_p: The C struct pointer for the tensor
|
||||||
|
:ivar _rank: The number of dimensions in the tensor
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, c_struct_p, rank):
|
||||||
|
"""Initialize a new Tensor.
|
||||||
|
|
||||||
|
:param c_struct_p: The C struct pointer for the tensor
|
||||||
|
:type c_struct_p: int
|
||||||
|
:param rank: The number of dimensions in the tensor
|
||||||
|
:type rank: int
|
||||||
|
"""
|
||||||
|
self._c_struct_p = c_struct_p
|
||||||
|
self._rank = rank
|
||||||
|
|
||||||
|
def __get_mlir_types__(self):
|
||||||
|
"""Get the MLIR types for this tensor.
|
||||||
|
|
||||||
|
Creates an LLVM structure type representing a C-structure with:
|
||||||
|
|
||||||
|
.. code-block:: c
|
||||||
|
|
||||||
|
struct Tensor {
|
||||||
|
void *ptr;
|
||||||
|
int32_t shape[3];
|
||||||
|
int32_t strides[3];
|
||||||
|
};
|
||||||
|
|
||||||
|
:return: A list containing the MLIR struct type
|
||||||
|
:rtype: list[llvm.StructType]
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Get the number of dimensions from the shape
|
||||||
|
ndim = self._rank
|
||||||
|
|
||||||
|
# Create the pointer type (void*)
|
||||||
|
ptr_type = llvm.PointerType.get()
|
||||||
|
|
||||||
|
# Create array types for shape and strides (int32_t[ndim])
|
||||||
|
int32_type = ir.IntegerType.get_signless(32)
|
||||||
|
shape_type = llvm.StructType.get_literal([int32_type] * ndim)
|
||||||
|
strides_type = llvm.StructType.get_literal([int32_type] * ndim)
|
||||||
|
|
||||||
|
# Create the structure type
|
||||||
|
struct_type = llvm.StructType.get_literal([ptr_type, shape_type, strides_type])
|
||||||
|
|
||||||
|
return [struct_type]
|
||||||
|
|
||||||
|
def __new_from_mlir_values__(self, values):
|
||||||
|
"""Create a new TensorValue from MLIR values.
|
||||||
|
|
||||||
|
:param values: A list of MLIR values
|
||||||
|
:type values: list[ir.Value]
|
||||||
|
:return: A new TensorValue instance
|
||||||
|
:rtype: TensorValue
|
||||||
|
"""
|
||||||
|
return ExampleTensorValue(values[0])
|
||||||
|
|
||||||
|
def __c_pointers__(self):
|
||||||
|
"""Get the C pointers for this tensor.
|
||||||
|
|
||||||
|
:return: A list containing the C struct pointer
|
||||||
|
:rtype: list[int]
|
||||||
|
"""
|
||||||
|
return [self._c_struct_p]
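# Note: __get_mlir_types__, __new_from_mlir_values__ and __c_pointers__ are the three hooks
# this example relies on so that cute.compile can treat ExampleTensor as a JIT argument:
# the first declares the LLVM struct type, the second rebuilds a value wrapper inside the
# JIT function body, and the third supplies the raw C pointer passed at call time.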
|
||||||
|
|
||||||
|
|
||||||
|
@cute.jit
|
||||||
|
def foo(tensor):
|
||||||
|
"""Example JIT function that prints tensor information.
|
||||||
|
|
||||||
|
:param tensor: A Tensor instance to print information about
|
||||||
|
:type tensor: Tensor
|
||||||
|
"""
|
||||||
|
cute.printf("data_ptr: {}", tensor.data_ptr)
|
||||||
|
cute.printf("shape: {}", tensor.shape)
|
||||||
|
cute.printf("stride: {}", tensor.stride)
|
||||||
|
|
||||||
|
mA = cute.make_tensor(
|
||||||
|
tensor.data_ptr, cute.make_layout(tensor.shape, stride=tensor.stride)
|
||||||
|
)
|
||||||
|
cute.print_tensor(mA)
|
||||||
|
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def run_test(tmpdir=None):
|
||||||
|
# Skip cleanup if user provides tmpdir
|
||||||
|
cleanup = tmpdir is None
|
||||||
|
# Initialize temporary build directory
|
||||||
|
tmpdir = tmpdir or tempfile.mkdtemp()
|
||||||
|
|
||||||
|
try:
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
subprocess.run(["cmake", "-B", tmpdir, current_dir], check=True)
|
||||||
|
subprocess.run(["cmake", "--build", tmpdir], check=True)
|
||||||
|
|
||||||
|
sys.path.append(tmpdir)
|
||||||
|
|
||||||
|
from tensor import make_tensor, pycapsule_get_pointer
|
||||||
|
|
||||||
|
# Mock test tensor and corresponding C structure for this example
|
||||||
|
# In production, this may come from an external library
|
||||||
|
x = torch.arange(2 * 8 * 4).to(torch.float32).reshape(2, 8, 4)
|
||||||
|
c_struct = make_tensor(x.data_ptr(), x.shape, x.stride())
|
||||||
|
c_struct_p = pycapsule_get_pointer(c_struct)
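# Keep `c_struct` (the PyCapsule) alive while `c_struct_p` is in use: the capsule's
# destructor defined in tensor.cpp frees the underlying MockTensor when it is collected.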
|
||||||
|
|
||||||
|
# Initialize tensor wrapper and compile test function
|
||||||
|
tensor = ExampleTensor(c_struct_p, len(x.shape))
|
||||||
|
compiled_func = cute.compile(foo, tensor)
|
||||||
|
|
||||||
|
# Benchmark pointer access performance
|
||||||
|
from time import time
|
||||||
|
|
||||||
|
start = time()
|
||||||
|
# Measure performance of critical path pointer access
|
||||||
|
# getting the C pointers is on the critical path when calling the JIT-compiled function
|
||||||
|
for _ in range(1000):
|
||||||
|
tensor.__c_pointers__()
|
||||||
|
end = time()
|
||||||
|
print(f"__c_pointers__: {(end - start) * 1000} us")
|
||||||
|
|
||||||
|
# Execute compiled function
|
||||||
|
compiled_func(tensor)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
finally:
|
||||||
|
if cleanup:
|
||||||
|
# Clean up the temporary directory
|
||||||
|
shutil.rmtree(tmpdir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Set temporary directory for building C modules"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tmp-dir", type=str, help="Temporary directory path for building C modules"
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
run_test(args.tmp_dir)
|
||||||
82
examples/python/CuTeDSL/cute/ffi/tensor.cpp
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
// Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
// Redistribution and use in source and binary forms, with or without
|
||||||
|
// modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
// 1. Redistributions of source code must retain the above copyright notice,
|
||||||
|
// this list of conditions and the following disclaimer.
|
||||||
|
|
||||||
|
// 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
// this list of conditions and the following disclaimer in the documentation
|
||||||
|
// and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
// 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
// contributors may be used to endorse or promote products derived from
|
||||||
|
// this software without specific prior written permission.
|
||||||
|
|
||||||
|
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||||
|
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||||
|
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||||
|
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||||
|
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||||
|
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||||
|
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||||
|
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||||
|
// POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <nanobind/nanobind.h>
|
||||||
|
#include <nanobind/stl/string.h>
|
||||||
|
#include <nanobind/stl/vector.h>
|
||||||
|
|
||||||
|
namespace nb = nanobind;
|
||||||
|
|
||||||
|
// Forward declaration of the MockTensor struct for testing only
|
||||||
|
struct MockTensor {
|
||||||
|
void *ptr;
|
||||||
|
struct {
|
||||||
|
int32_t shape[3];
|
||||||
|
} shape;
|
||||||
|
|
||||||
|
struct {
|
||||||
|
int32_t strides[3];
|
||||||
|
} strides;
|
||||||
|
};
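// Note: MockTensor mirrors the `struct Tensor { void *ptr; int32_t shape[3]; int32_t strides[3]; }`
// layout documented in jit_argument.py, so the raw pointer returned by pycapsule_get_pointer
// can be consumed directly by the ExampleTensor wrapper on the Python side.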
|
||||||
|
|
||||||
|
NB_MODULE(tensor, m) {
|
||||||
|
// create a tensor for testing
|
||||||
|
m.def("make_tensor", [](int64_t ptr, std::vector<int32_t> shape,
|
||||||
|
std::vector<int32_t> strides) {
|
||||||
|
auto *tensor = new MockTensor();
|
||||||
|
tensor->ptr = reinterpret_cast<void *>(ptr);
|
||||||
|
|
||||||
|
assert(shape.size() == 3 && "shape must have 3 elements");
|
||||||
|
assert(strides.size() == 3 && "strides must have 3 elements");
|
||||||
|
|
||||||
|
for (size_t i = 0; i < shape.size(); i++) {
|
||||||
|
tensor->shape.shape[i] = shape[i];
|
||||||
|
tensor->strides.strides[i] = strides[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
return nb::steal(PyCapsule_New(tensor, "tensor", [](PyObject *capsule) {
|
||||||
|
auto n = PyCapsule_GetName(capsule);
|
||||||
|
if (void *p = PyCapsule_GetPointer(capsule, n)) {
|
||||||
|
delete reinterpret_cast<MockTensor *>(p);
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
});
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
"pycapsule_get_pointer",
|
||||||
|
[](nb::object &capsule) {
|
||||||
|
void *ptr = PyCapsule_GetPointer(capsule.ptr(), "tensor");
|
||||||
|
if (!ptr) {
|
||||||
|
throw std::runtime_error("Invalid tensor capsule");
|
||||||
|
}
|
||||||
|
return reinterpret_cast<uintptr_t>(ptr);
|
||||||
|
},
|
||||||
|
"Get pointer from PyCapsule");
|
||||||
|
}
|
||||||
1486
examples/python/CuTeDSL/hopper/dense_gemm.py
Normal file
File diff suppressed because it is too large
@ -83,11 +83,6 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" # Print hello world from host code\n",
|
" # Print hello world from host code\n",
|
||||||
" cute.printf(\"hello world\")\n",
|
" cute.printf(\"hello world\")\n",
|
||||||
" \n",
|
|
||||||
" # Initialize CUDA context for launching a kernel with error checking\n",
|
|
||||||
" # We make context initialization explicit to allow users to control the context creation \n",
|
|
||||||
" # and avoid potential issues with multiple contexts\n",
|
|
||||||
" cutlass.cuda.initialize_cuda_context()\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
" # Launch kernel\n",
|
" # Launch kernel\n",
|
||||||
" kernel().launch(\n",
|
" kernel().launch(\n",
|
||||||
@ -129,6 +124,11 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
|
"# Initialize CUDA context for launching a kernel with error checking\n",
|
||||||
|
"# We make context initialization explicit to allow users to control the context creation \n",
|
||||||
|
"# and avoid potential issues with multiple contexts\n",
|
||||||
|
"cutlass.cuda.initialize_cuda_context()\n",
|
||||||
|
"\n",
|
||||||
"# Method 1: Just-In-Time (JIT) compilation - compiles and runs the code immediately\n",
|
"# Method 1: Just-In-Time (JIT) compilation - compiles and runs the code immediately\n",
|
||||||
"print(\"Running hello_world()...\")\n",
|
"print(\"Running hello_world()...\")\n",
|
||||||
"hello_world()\n",
|
"hello_world()\n",
|
||||||
@ -136,6 +136,7 @@
|
|||||||
"# Method 2: Compile first (useful if you want to run the same code multiple times)\n",
|
"# Method 2: Compile first (useful if you want to run the same code multiple times)\n",
|
||||||
"print(\"Compiling...\")\n",
|
"print(\"Compiling...\")\n",
|
||||||
"hello_world_compiled = cute.compile(hello_world)\n",
|
"hello_world_compiled = cute.compile(hello_world)\n",
|
||||||
|
"\n",
|
||||||
"# Run the pre-compiled version\n",
|
"# Run the pre-compiled version\n",
|
||||||
"print(\"Running compiled version...\")\n",
|
"print(\"Running compiled version...\")\n",
|
||||||
"hello_world_compiled()"
|
"hello_world_compiled()"
|
||||||
|
|||||||
@ -33,7 +33,6 @@
|
|||||||
#include <cute/config.hpp>
|
#include <cute/config.hpp>
|
||||||
|
|
||||||
#include <cute/tensor_impl.hpp>
|
#include <cute/tensor_impl.hpp>
|
||||||
#include <cute/tensor_predicate.hpp>
|
|
||||||
|
|
||||||
namespace cute
|
namespace cute
|
||||||
{
|
{
|
||||||
@ -45,7 +44,7 @@ template <class Alpha,
|
|||||||
class XEngine, class XLayout,
|
class XEngine, class XLayout,
|
||||||
class Beta,
|
class Beta,
|
||||||
class YEngine, class YLayout,
|
class YEngine, class YLayout,
|
||||||
class PrdTensor = TrivialPredTensor>
|
class PrdTensor = constant_fn<true_type>>
|
||||||
CUTE_HOST_DEVICE
|
CUTE_HOST_DEVICE
|
||||||
void
|
void
|
||||||
axpby(Alpha const& alpha,
|
axpby(Alpha const& alpha,
|
||||||
@ -64,7 +63,7 @@ template <class Alpha,
|
|||||||
class XEngine, class XLayout,
|
class XEngine, class XLayout,
|
||||||
class Beta,
|
class Beta,
|
||||||
class YEngine, class YLayout,
|
class YEngine, class YLayout,
|
||||||
class PrdTensor = TrivialPredTensor>
|
class PrdTensor = constant_fn<true_type>>
|
||||||
CUTE_HOST_DEVICE
|
CUTE_HOST_DEVICE
|
||||||
void
|
void
|
||||||
axpby(Alpha const& alpha,
|
axpby(Alpha const& alpha,
|
||||||
|
|||||||
@ -36,7 +36,6 @@
|
|||||||
#include <cute/swizzle.hpp> // cute::Swizzle
|
#include <cute/swizzle.hpp> // cute::Swizzle
|
||||||
#include <cute/swizzle_layout.hpp> // cute::get_nonswizzle_portion
|
#include <cute/swizzle_layout.hpp> // cute::get_nonswizzle_portion
|
||||||
#include <cute/tensor_impl.hpp> // cute::Tensor
|
#include <cute/tensor_impl.hpp> // cute::Tensor
|
||||||
#include <cute/tensor_predicate.hpp>
|
|
||||||
#include <cute/algorithm/copy.hpp>
|
#include <cute/algorithm/copy.hpp>
|
||||||
#include <cute/atom/copy_atom.hpp>
|
#include <cute/atom/copy_atom.hpp>
|
||||||
|
|
||||||
|
|||||||
@ -32,7 +32,6 @@
|
|||||||
|
|
||||||
#include <cute/config.hpp> // CUTE_HOST_DEVICE
|
#include <cute/config.hpp> // CUTE_HOST_DEVICE
|
||||||
#include <cute/tensor_impl.hpp> // cute::Tensor
|
#include <cute/tensor_impl.hpp> // cute::Tensor
|
||||||
#include <cute/tensor_predicate.hpp> // cute::TrivialPredTensor
|
|
||||||
#include <cute/atom/copy_atom.hpp> // cute::Copy_Atom
|
#include <cute/atom/copy_atom.hpp> // cute::Copy_Atom
|
||||||
|
|
||||||
namespace cute
|
namespace cute
|
||||||
@ -66,10 +65,45 @@ copy_if(PrdTensor const& pred,
|
|||||||
// copy_if -- Predicated CopyAtom
|
// copy_if -- Predicated CopyAtom
|
||||||
//
|
//
|
||||||
|
|
||||||
|
// Predicate Tensor is an Actual Tensor
|
||||||
|
template <class... CopyArgs,
|
||||||
|
class PrdEngine, class PrdLayout,
|
||||||
|
class SrcEngine, class SrcLayout,
|
||||||
|
class DstEngine, class DstLayout>
|
||||||
|
CUTE_HOST_DEVICE
|
||||||
|
void
|
||||||
|
copy_if(Copy_Atom<CopyArgs...> const& copy_atom,
|
||||||
|
Tensor<PrdEngine, PrdLayout> const& prd, // ([V],Rest...)
|
||||||
|
Tensor<SrcEngine, SrcLayout> const& src, // ( V, Rest...)
|
||||||
|
Tensor<DstEngine, DstLayout> & dst) // ( V, Rest...)
|
||||||
|
{
|
||||||
|
if constexpr (PrdLayout::rank == SrcLayout::rank - 1) {
|
||||||
|
// Back-compat ONLY -- Delete?
|
||||||
|
copy_if(copy_atom, make_tensor(prd.data(), prepend(prd.layout(), Layout<_1,_0>{})), src, dst);
|
||||||
|
} else {
|
||||||
|
static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch.");
|
||||||
|
static_assert(SrcLayout::rank == PrdLayout::rank, "CopyAtom rank-mismatch.");
|
||||||
|
|
||||||
|
if constexpr (SrcLayout::rank == 1) { // Dispatch the copy
|
||||||
|
copy_atom.call(prd, src, dst);
|
||||||
|
} else { // Loop over all but the first mode
|
||||||
|
constexpr int R = SrcLayout::rank;
|
||||||
|
Tensor prd_v = group_modes<1,R>(prd);
|
||||||
|
Tensor src_v = group_modes<1,R>(src);
|
||||||
|
Tensor dst_v = group_modes<1,R>(dst);
|
||||||
|
CUTE_UNROLL
|
||||||
|
for (int i = 0; i < size<1>(dst_v); ++i) {
|
||||||
|
copy_atom.call(prd_v(_,i), src_v(_,i), dst_v(_,i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <class... CopyArgs,
|
template <class... CopyArgs,
|
||||||
class PredTensor,
|
class PredTensor,
|
||||||
class SrcEngine, class SrcLayout,
|
class SrcEngine, class SrcLayout,
|
||||||
class DstEngine, class DstLayout>
|
class DstEngine, class DstLayout>
|
||||||
|
[[deprecated("Use a bool-tensor or transform-tensor as predication.")]]
|
||||||
CUTE_HOST_DEVICE
|
CUTE_HOST_DEVICE
|
||||||
void
|
void
|
||||||
copy_if(Copy_Atom<CopyArgs...> const& copy_atom,
|
copy_if(Copy_Atom<CopyArgs...> const& copy_atom,
|
||||||
@ -77,33 +111,14 @@ copy_if(Copy_Atom<CopyArgs...> const& copy_atom,
|
|||||||
Tensor<SrcEngine, SrcLayout> const& src, // (V,Rest...)
|
Tensor<SrcEngine, SrcLayout> const& src, // (V,Rest...)
|
||||||
Tensor<DstEngine, DstLayout> & dst) // (V,Rest...)
|
Tensor<DstEngine, DstLayout> & dst) // (V,Rest...)
|
||||||
{
|
{
|
||||||
static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch.");
|
Tensor tpred = cute::lazy::transform(make_tensor(counting_iterator<int>{}, replace<0>(shape(dst), _1{})), pred);
|
||||||
auto has_with_bool = cute::is_valid([](auto t)->void_t<decltype(declval<typename decltype(t)::Traits>().with(true))>{}, copy_atom);
|
return copy_if(copy_atom, tpred, src, dst);
|
||||||
|
|
||||||
if constexpr (SrcLayout::rank == 1) { // Dispatch the copy
|
|
||||||
if constexpr (has_with_bool) {
|
|
||||||
copy_atom.with(pred()).call(src, dst);
|
|
||||||
} else {
|
|
||||||
if (pred()) { copy_atom.call(src, dst); }
|
|
||||||
}
|
|
||||||
} else { // Loop over all but the first mode
|
|
||||||
constexpr int R = SrcLayout::rank;
|
|
||||||
Tensor src_v = group_modes<1,R>(src);
|
|
||||||
Tensor dst_v = group_modes<1,R>(dst);
|
|
||||||
CUTE_UNROLL
|
|
||||||
for (int i = 0; i < size<1>(dst_v); ++i) {
|
|
||||||
if constexpr (has_with_bool) {
|
|
||||||
copy_atom.with(pred(i)).call(src_v(_,i), dst_v(_,i));
|
|
||||||
} else {
|
|
||||||
if (pred(i)) { copy_atom.call(src_v(_,i), dst_v(_,i)); }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// copy_if -- AutoCopyAsync
|
// copy_if -- AutoCopyAsync
|
||||||
//
|
//
|
||||||
|
|
||||||
template <class PrdTensor,
|
template <class PrdTensor,
|
||||||
class SrcEngine, class SrcLayout,
|
class SrcEngine, class SrcLayout,
|
||||||
class DstEngine, class DstLayout>
|
class DstEngine, class DstLayout>
|
||||||
@ -159,7 +174,7 @@ copy(AutoCopyAsync const& cpy,
|
|||||||
Tensor<SrcEngine, SrcLayout> const& src, // (V,Rest...)
|
Tensor<SrcEngine, SrcLayout> const& src, // (V,Rest...)
|
||||||
Tensor<DstEngine, DstLayout> & dst) // (V,Rest...)
|
Tensor<DstEngine, DstLayout> & dst) // (V,Rest...)
|
||||||
{
|
{
|
||||||
copy_if(cpy, TrivialPredTensor{}, src, dst);
|
copy_if(cpy, constant_fn<true_type>{}, src, dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -202,7 +217,7 @@ copy(Copy_Atom<CopyArgs...> const& copy_atom,
|
|||||||
Tensor dst_c = dst_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_)); // (V, Rest)
|
Tensor dst_c = dst_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_)); // (V, Rest)
|
||||||
Tensor src_c = src_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_)); // (V, Rest)
|
Tensor src_c = src_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_)); // (V, Rest)
|
||||||
|
|
||||||
CUTE_STATIC_ASSERT_V(size<1>(src_c) == size<1>(dst_c));
|
CUTE_STATIC_ASSERT_V( size<1>(src_c) == size<1>(dst_c));
|
||||||
CUTE_STATIC_ASSERT_V(shape<0>(dst_c) == shape<0>(dst));
|
CUTE_STATIC_ASSERT_V(shape<0>(dst_c) == shape<0>(dst));
|
||||||
CUTE_STATIC_ASSERT_V(shape<0>(src_c) == shape<0>(src));
|
CUTE_STATIC_ASSERT_V(shape<0>(src_c) == shape<0>(src));
|
||||||
|
|
||||||
@ -224,7 +239,7 @@ copy(Copy_Atom<CopyArgs...> const& copy_atom,
|
|||||||
////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////
|
||||||
|
|
||||||
// Specialization for AutoVectorizingCopyAssumedAlignment<MaxVecBits>
|
// Specialization for AutoVectorizingCopyAssumedAlignment<MaxVecBits>
|
||||||
template <int MaxVecBits, class... Args,
|
template <int MaxVecBits,
|
||||||
class SrcEngine, class SrcLayout,
|
class SrcEngine, class SrcLayout,
|
||||||
class DstEngine, class DstLayout>
|
class DstEngine, class DstLayout>
|
||||||
CUTE_HOST_DEVICE
|
CUTE_HOST_DEVICE
|
||||||
@ -234,23 +249,30 @@ copy(AutoVectorizingCopyWithAssumedAlignment<MaxVecBits> const&,
|
|||||||
Tensor<DstEngine, DstLayout> & dst)
|
Tensor<DstEngine, DstLayout> & dst)
|
||||||
{
|
{
|
||||||
constexpr int common_elem = CUTE_STATIC_V(max_common_vector(src, dst));
|
constexpr int common_elem = CUTE_STATIC_V(max_common_vector(src, dst));
|
||||||
constexpr int align_bits = CUTE_STATIC_V(gcd(max_alignment(src), max_alignment(dst), Int<MaxVecBits>{}));
|
static_assert(is_integral<decltype(Int<common_elem>{} * sizeof_bits_v<typename DstEngine::value_type>)>::value, "Error: Attempting a subbit write!");
|
||||||
static_assert(is_integral<decltype(Int<common_elem>{} * sizeof_bits_v<typename SrcEngine::value_type>)>::value, "Error: Attempting a subbit copy!");
|
|
||||||
constexpr int vec_bits = gcd(common_elem * sizeof_bits_v<typename SrcEngine::value_type>, align_bits);
|
|
||||||
|
|
||||||
if constexpr (common_elem > 1 && ((vec_bits % 8) == 0)) {
|
if constexpr (common_elem > 1)
|
||||||
// If more than one element vectorizes to 8bits or more, then recast and copy
|
{
|
||||||
using VecType = uint_bit_t<vec_bits>;
|
constexpr int align_bits = CUTE_STATIC_V(gcd(max_alignment(src), max_alignment(dst), Int<MaxVecBits>{}));
|
||||||
// Preserve volatility
|
constexpr int vec_bits = gcd(common_elem * sizeof_bits_v<typename SrcEngine::value_type>, align_bits);
|
||||||
using SrcVecType = conditional_t<is_volatile_v<typename SrcEngine::element_type>, VecType const volatile, VecType const>;
|
|
||||||
using DstVecType = conditional_t<is_volatile_v<typename DstEngine::element_type>, VecType volatile, VecType >;
|
|
||||||
|
|
||||||
// Recast
|
if constexpr ((vec_bits % 8) == 0)
|
||||||
Tensor src_v = recast<SrcVecType>(src);
|
{
|
||||||
Tensor dst_v = recast<DstVecType>(dst);
|
// If more than one element vectorizes to 8bits or more, then recast and copy
|
||||||
return copy_if(TrivialPredTensor{}, src_v, dst_v);
|
using VecType = uint_bit_t<vec_bits>;
|
||||||
|
// Preserve volatility
|
||||||
|
using SrcVecType = conditional_t<is_volatile_v<typename SrcEngine::element_type>, VecType const volatile, VecType const>;
|
||||||
|
using DstVecType = conditional_t<is_volatile_v<typename DstEngine::element_type>, VecType volatile, VecType >;
|
||||||
|
|
||||||
|
// Recast
|
||||||
|
Tensor src_v = recast<SrcVecType>(src);
|
||||||
|
Tensor dst_v = recast<DstVecType>(dst);
|
||||||
|
return copy_if(constant_fn<true_type>{}, src_v, dst_v);
|
||||||
|
} else {
|
||||||
|
return copy_if(constant_fn<true_type>{}, src, dst);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
return copy_if(TrivialPredTensor{}, src, dst);
|
return copy_if(constant_fn<true_type>{}, src, dst);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -277,7 +299,7 @@ copy(AutoFilter<CopyOp> const& copy_op,
|
|||||||
Tensor src_n = zipped_divide(src, dst_null);
|
Tensor src_n = zipped_divide(src, dst_null);
|
||||||
|
|
||||||
CUTE_STATIC_ASSERT_V(cosize<0>(dst_n.layout()) == Int<1>{}, "Nullspace definition error");
|
CUTE_STATIC_ASSERT_V(cosize<0>(dst_n.layout()) == Int<1>{}, "Nullspace definition error");
|
||||||
CUTE_STATIC_ASSERT_V(cosize<0>(src_n.layout()) == Int<1>{}, "Error: Ambiguous scatter detected in copy");
|
CUTE_STATIC_ASSERT_V(cosize<0>(src_n.layout()) == Int<1>{}, "Error: Ambiguous race-condition detected.");
|
||||||
|
|
||||||
copy(copy_op.base, src_n(Int<0>{},_), dst_n(Int<0>{},_));
|
copy(copy_op.base, src_n(Int<0>{},_), dst_n(Int<0>{},_));
|
||||||
} else {
|
} else {
|
||||||
@ -335,6 +357,18 @@ copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>, Args...> con
|
|||||||
return copy(AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>{}, src, dst);
|
return copy(AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>{}, src, dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int MaxVecBits, class... Args,
|
||||||
|
class SrcEngine, class SrcLayout,
|
||||||
|
class DstEngine, class DstLayout>
|
||||||
|
CUTE_HOST_DEVICE
|
||||||
|
void
|
||||||
|
copy(Copy_Atom<Copy_Traits<AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>>, Args...> const&,
|
||||||
|
Tensor<SrcEngine, SrcLayout> const& src,
|
||||||
|
Tensor<DstEngine, DstLayout> & dst)
|
||||||
|
{
|
||||||
|
return copy(AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>{}, src, dst);
|
||||||
|
}
|
||||||
|
|
||||||
#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)
|
#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)
|
||||||
template <class... CT_Args,
|
template <class... CT_Args,
|
||||||
class SrcEngine, class SrcLayout,
|
class SrcEngine, class SrcLayout,
|
||||||
@ -375,8 +409,8 @@ template <class... CT_Args, class... CA_Args,
|
|||||||
CUTE_HOST_DEVICE
|
CUTE_HOST_DEVICE
|
||||||
void
|
void
|
||||||
copy(Copy_Atom<Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...>, CA_Args...> const& atom,
|
copy(Copy_Atom<Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...>, CA_Args...> const& atom,
|
||||||
Tensor<SrcEngine, SrcLayout> const& src,
|
Tensor<SrcEngine, SrcLayout> const& src,
|
||||||
Tensor<DstEngine, DstLayout> & dst)
|
Tensor<DstEngine, DstLayout> & dst)
|
||||||
{
|
{
|
||||||
return copy(static_cast<Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...> const&>(atom), src, dst);
|
return copy(static_cast<Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...> const&>(atom), src, dst);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -90,18 +90,19 @@ constexpr bool has_prefetch<CopyOp, void_t<typename CopyOp::PREFETCH>> = true;
|
|||||||
|
|
||||||
} // end namespace detail
|
} // end namespace detail
|
||||||
|
|
||||||
template <class CopyOp, class... CT_Args, class... CA_Args,
|
template <class CopyOp, class... CT_Args, class CopyType,
|
||||||
class GEngine, class GLayout>
|
class GEngine, class GLayout>
|
||||||
CUTE_HOST_DEVICE
|
CUTE_HOST_DEVICE
|
||||||
void
|
void
|
||||||
prefetch(Copy_Atom<Copy_Traits<CopyOp, CT_Args...>, CA_Args...> const& atom,
|
prefetch(Copy_Atom<Copy_Traits<CopyOp, CT_Args...>, CopyType> const& atom,
|
||||||
Tensor<GEngine, GLayout> const& src)
|
Tensor<GEngine, GLayout> const& src)
|
||||||
{
|
{
|
||||||
if constexpr (detail::has_prefetch<CopyOp>) {
|
if constexpr (detail::has_prefetch<CopyOp>) {
|
||||||
using Prefetch_Traits = Copy_Traits<typename CopyOp::PREFETCH, CT_Args...>;
|
using Prefetch_Traits = Copy_Traits<typename CopyOp::PREFETCH, CT_Args...>;
|
||||||
using Prefetch_Atom = Copy_Atom<Prefetch_Traits, CA_Args...>;
|
using Prefetch_Atom = Copy_Atom<Prefetch_Traits, CopyType>;
|
||||||
Prefetch_Atom prefetch_atom{atom};
|
Prefetch_Atom prefetch_atom{atom};
|
||||||
auto& dst = const_cast<Tensor<GEngine, GLayout>&>(src); // dst is ignored for prefetch atoms
|
//auto& dst = const_cast<Tensor<GEngine, GLayout>&>(src); // dst is ignored for prefetch atoms
|
||||||
|
Tensor dst = make_tensor(make_smem_ptr<CopyType>(nullptr), shape(src));
|
||||||
return copy(prefetch_atom, src, dst);
|
return copy(prefetch_atom, src, dst);
|
||||||
} else {
|
} else {
|
||||||
return prefetch(src);
|
return prefetch(src);
|
||||||
|
|||||||
@ -163,4 +163,16 @@ transform(Tensor<EngineIn1,LayoutIn1> const& tensor_in1,
|
|||||||
return transform(tensor_in1, tensor_in2, tensor_out, op);
|
return transform(tensor_in1, tensor_in2, tensor_out, op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace lazy {
|
||||||
|
|
||||||
|
template <class Engine, class Layout, class Fn>
|
||||||
|
CUTE_HOST_DEVICE constexpr
|
||||||
|
auto
|
||||||
|
transform(cute::Tensor<Engine,Layout> const& t, Fn const& fn)
|
||||||
|
{
|
||||||
|
return cute::make_tensor(cute::make_transform_iter(fn, t.data()), t.layout());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // end namespace lazy
|
||||||
|
|
||||||
} // end namespace cute
|
} // end namespace cute
|
||||||
|
|||||||
107
include/cute/algorithm/tensor_reduce.hpp
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include <cute/config.hpp>
|
||||||
|
#include <cute/tensor_impl.hpp>
|
||||||
|
#include <cute/algorithm/functional.hpp>
|
||||||
|
#include <cute/algorithm/fill.hpp>
|
||||||
|
|
||||||
|
namespace cute
|
||||||
|
{
|
||||||
|
|
||||||
|
// Reduce @src tensor using binary reduction operator @op and initial value @init and return a scalar.
|
||||||
|
template <class SrcEngine, class SrcLayout, class T, class BinaryOp = cute::plus>
|
||||||
|
CUTE_HOST_DEVICE constexpr
|
||||||
|
T
|
||||||
|
reduce(Tensor<SrcEngine,SrcLayout> const& src, T init, BinaryOp op = {})
|
||||||
|
{
|
||||||
|
for (auto i = 0; i < size(src); ++i) {
|
||||||
|
init = op(init, src(i));
|
||||||
|
}
|
||||||
|
return init;
|
||||||
|
}
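// Usage sketch (assuming a float-valued tensor t): float total = cute::reduce(t, 0.0f);
// sums all elements, since the default BinaryOp is cute::plus.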
|
||||||
|
|
||||||
|
// Reduce @src tensor RedMode using binary reduction operator @op and store the result in @dst tensor
|
||||||
|
// for each index in @dst/BatchMode.
|
||||||
|
// @pre @src tensor has rank 2
|
||||||
|
// @pre size of @src batch mode is equal to size of @dst batch mode
|
||||||
|
template <class SrcEngine, class SrcLayout,
|
||||||
|
class DstEngine, class DstLayout,
|
||||||
|
class BinaryOp = cute::plus>
|
||||||
|
CUTE_HOST_DEVICE constexpr
|
||||||
|
void
|
||||||
|
batch_reduce(Tensor<SrcEngine, SrcLayout> const& src, // (RedMode, BatchMode)
|
||||||
|
Tensor<DstEngine, DstLayout> & dst, // (BatchMode)
|
||||||
|
BinaryOp op = {})
|
||||||
|
{
|
||||||
|
// Precondition
|
||||||
|
CUTE_STATIC_ASSERT_V(rank(src) == Int<2>{});
|
||||||
|
assert(size<1>(src) == size(dst));
|
||||||
|
|
||||||
|
for (int i = 0; i < size(dst); ++i) {
|
||||||
|
dst(i) = reduce(src(_,i), dst(i), op);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Reduce @src tensor along selected modes specified in @target_profile using binary reduction operator @op
|
||||||
|
// and store the result in @dst tensor. @target_profile is a tuple where '_' indicates modes to keep and
|
||||||
|
// integers indicates modes to reduce.
|
||||||
|
// @pre @target_profile is compatible with @src layout
|
||||||
|
template <class SrcEngine, class SrcLayout,
|
||||||
|
class DstEngine, class DstLayout,
|
||||||
|
class TargetProfile,
|
||||||
|
class BinaryOp = cute::plus>
|
||||||
|
CUTE_HOST_DEVICE constexpr
|
||||||
|
void
|
||||||
|
logical_reduce(Tensor<SrcEngine, SrcLayout> const& src,
|
||||||
|
Tensor<DstEngine, DstLayout> & dst,
|
||||||
|
TargetProfile const& target_profile,
|
||||||
|
BinaryOp op = {})
|
||||||
|
{
|
||||||
|
// Precondition
|
||||||
|
assert(compatible(target_profile, shape(src)));
|
||||||
|
|
||||||
|
auto diced_layout = dice(target_profile, src.layout());
|
||||||
|
auto sliced_layout = slice(target_profile, src.layout());
|
||||||
|
|
||||||
|
auto red_mode = conditional_return<rank(diced_layout) == Int<0>{}>(Layout<_1,_0>{}, diced_layout);
|
||||||
|
auto batch_mode = conditional_return<rank(sliced_layout) == Int<0>{}>(Layout<_1,_0>{}, sliced_layout);
|
||||||
|
|
||||||
|
auto src_tensor = make_tensor(src.data(), make_layout(red_mode, batch_mode));
|
||||||
|
|
||||||
|
batch_reduce(src_tensor, dst, op);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // end namespace cute
|
||||||
@ -123,6 +123,56 @@ struct Copy_Atom<Copy_Traits<Args...>, CopyInternalType>
|
|||||||
{
|
{
|
||||||
return call(src, dst);
|
return call(src, dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check and call instruction, or recurse
|
||||||
|
template <class PEngine, class PLayout,
|
||||||
|
class SEngine, class SLayout,
|
||||||
|
class DEngine, class DLayout>
|
||||||
|
CUTE_HOST_DEVICE
|
||||||
|
void
|
||||||
|
call(Tensor<PEngine,PLayout> const& prd,
|
||||||
|
Tensor<SEngine,SLayout> const& src,
|
||||||
|
Tensor<DEngine,DLayout> & dst) const
|
||||||
|
{
|
||||||
|
static_assert(PLayout::rank == 1, "Expected rank-1 prd tensor");
|
||||||
|
static_assert(SLayout::rank == 1, "Expected rank-1 src tensor");
|
||||||
|
static_assert(DLayout::rank == 1, "Expected rank-1 dst tensor");
|
||||||
|
|
||||||
|
if constexpr (is_constant<NumValSrc, decltype(size(src))>::value ||
|
||||||
|
is_constant<NumValDst, decltype(size(dst))>::value) {
|
||||||
|
// Dispatch to unpack to execute instruction
|
||||||
|
Traits const& traits = static_cast<Traits const&>(*this);
|
||||||
|
auto has_with_bool = cute::is_valid([](auto t)->void_t<decltype(t.with(true))>{}, traits);
|
||||||
|
if constexpr (has_with_bool) {
|
||||||
|
copy_unpack(traits.with(prd(Int<0>{})), src, dst);
|
||||||
|
} else {
|
||||||
|
if (prd(Int<0>{})) { copy_unpack(traits, src, dst); }
|
||||||
|
}
|
||||||
|
} else if constexpr (is_tuple<decltype(shape(prd))>::value &&
|
||||||
|
is_tuple<decltype(shape(src))>::value &&
|
||||||
|
is_tuple<decltype(shape(dst))>::value) {
|
||||||
|
// If the size of the src/dst doesn't match the instruction,
|
||||||
|
// recurse this rank-1 layout by peeling off the mode
|
||||||
|
// ((A,B,C,...)) -> (A,B,C,...)
|
||||||
|
return copy_if(*this, tensor<0>(prd), tensor<0>(src), tensor<0>(dst));
|
||||||
|
} else {
|
||||||
|
static_assert(dependent_false<SEngine>,
|
||||||
|
"CopyAtom: Src/Dst partitioning does not match the instruction requirement.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accept mutable temporaries
|
||||||
|
template <class PEngine, class PLayout,
|
||||||
|
class SEngine, class SLayout,
|
||||||
|
class DEngine, class DLayout>
|
||||||
|
CUTE_HOST_DEVICE
|
||||||
|
void
|
||||||
|
call(Tensor<PEngine,PLayout> const& prd,
|
||||||
|
Tensor<SEngine,SLayout> const& src,
|
||||||
|
Tensor<DEngine,DLayout> && dst) const
|
||||||
|
{
|
||||||
|
return call(prd, src, dst);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -733,13 +783,13 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and
|
|||||||
#include <cute/atom/copy_traits_sm75.hpp>
|
#include <cute/atom/copy_traits_sm75.hpp>
|
||||||
#include <cute/atom/copy_traits_sm80.hpp>
|
#include <cute/atom/copy_traits_sm80.hpp>
|
||||||
#include <cute/atom/copy_traits_sm90.hpp>
|
#include <cute/atom/copy_traits_sm90.hpp>
|
||||||
#include <cute/atom/copy_traits_sm100.hpp>
|
#include <cute/atom/copy_traits_sm100.hpp>
|
||||||
|
|
||||||
|
|
||||||
// Config
|
// Config
|
||||||
#if (__CUDACC_VER_MAJOR__ >= 12)
|
#if (__CUDACC_VER_MAJOR__ >= 12)
|
||||||
# define CUTE_COPY_ATOM_TMA_SM90_ENABLED
|
# define CUTE_COPY_ATOM_TMA_SM90_ENABLED
|
||||||
# define CUTE_COPY_ATOM_TMA_SM100_ENABLED
|
# define CUTE_COPY_ATOM_TMA_SM100_ENABLED
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -235,6 +235,63 @@ raw_pointer_cast(counting_iterator<T> const& x) {
|
|||||||
return x.n_;
|
return x.n_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// transform_iterator
|
||||||
|
//
|
||||||
|
|
||||||
|
template <class Fn, class Iter>
|
||||||
|
struct transform_iter
|
||||||
|
{
|
||||||
|
using iterator = Iter;
|
||||||
|
// using reference = typename iterator_traits<iterator>::reference;
|
||||||
|
// using element_type = typename iterator_traits<iterator>::element_type;
|
||||||
|
// using value_type = typename iterator_traits<iterator>::value_type;
|
||||||
|
|
||||||
|
Fn fn_;
|
||||||
|
iterator ptr_;
|
||||||
|
|
||||||
|
CUTE_HOST_DEVICE constexpr
|
||||||
|
transform_iter(Fn fn, iterator ptr = {}) : fn_(fn), ptr_(ptr) {}
|
||||||
|
|
||||||
|
CUTE_HOST_DEVICE constexpr
|
||||||
|
decltype(auto) operator*() const { return fn_(*ptr_); }
|
||||||
|
|
||||||
|
template <class Index>
|
||||||
|
CUTE_HOST_DEVICE constexpr
|
||||||
|
decltype(auto) operator[](Index const& i) const { return fn_(ptr_[i]); }
|
||||||
|
|
||||||
|
template <class Index>
|
||||||
|
CUTE_HOST_DEVICE constexpr
|
||||||
|
auto operator+(Index const& i) const { return transform_iter<Fn, decltype(ptr_+i)>{fn_, ptr_+i}; }
|
||||||
|
|
||||||
|
template <class IterY>
|
||||||
|
CUTE_HOST_DEVICE constexpr
|
||||||
|
friend bool operator==(transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ == y.ptr_; }
|
||||||
|
template <class IterY>
|
||||||
|
CUTE_HOST_DEVICE constexpr
|
||||||
|
friend bool operator!=(transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ != y.ptr_; }
|
||||||
|
template <class IterY>
|
||||||
|
CUTE_HOST_DEVICE constexpr
|
||||||
|
friend bool operator< (transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ < y.ptr_; }
|
||||||
|
template <class IterY>
|
||||||
|
CUTE_HOST_DEVICE constexpr
|
||||||
|
friend bool operator<=(transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ <= y.ptr_; }
|
||||||
|
template <class IterY>
|
||||||
|
CUTE_HOST_DEVICE constexpr
|
||||||
|
friend bool operator> (transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ > y.ptr_; }
|
||||||
|
template <class IterY>
|
||||||
|
CUTE_HOST_DEVICE constexpr
|
||||||
|
friend bool operator>=(transform_iter<Fn,Iter> const& x, transform_iter<Fn,IterY> const& y) { return x.ptr_ >= y.ptr_; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <class Fn, class Iterator>
|
||||||
|
CUTE_HOST_DEVICE constexpr
|
||||||
|
auto
|
||||||
|
make_transform_iter(Fn const& fn, Iterator const& ptr)
|
||||||
|
{
|
||||||
|
return transform_iter<Fn,Iterator>(fn,ptr);
|
||||||
|
}
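// Usage sketch: make_tensor(make_transform_iter([](auto x){ return 2*x; }, t.data()), t.layout())
// yields a lazily transformed view of t; cute::lazy::transform builds its result the same way.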
|
||||||
|
|
||||||
//
|
//
|
||||||
// Display utilities
|
// Display utilities
|
||||||
//
|
//
|
||||||
@ -251,12 +308,24 @@ CUTE_HOST_DEVICE void print(counting_iterator<T> ptr)
|
|||||||
printf("counting_iter("); print(ptr.n_); printf(")");
|
printf("counting_iter("); print(ptr.n_); printf(")");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <class Fn, class Iterator>
|
||||||
|
CUTE_HOST_DEVICE void print(transform_iter<Fn,Iterator> ptr)
|
||||||
|
{
|
||||||
|
printf("trans_"); print(ptr.ptr_);
|
||||||
|
}
|
||||||
|
|
||||||
#if !defined(__CUDACC_RTC__)
|
#if !defined(__CUDACC_RTC__)
|
||||||
template <class T>
|
template <class T>
|
||||||
CUTE_HOST std::ostream& operator<<(std::ostream& os, counting_iterator<T> ptr)
|
CUTE_HOST std::ostream& operator<<(std::ostream& os, counting_iterator<T> ptr)
|
||||||
{
|
{
|
||||||
return os << "counting_iter(" << ptr.n_ << ")";
|
return os << "counting_iter(" << ptr.n_ << ")";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <class Fn, class Iterator>
|
||||||
|
CUTE_HOST std::ostream& operator<<(std::ostream& os, transform_iter<Fn,Iterator> ptr)
|
||||||
|
{
|
||||||
|
return os << "trans_" << ptr.ptr_;
|
||||||
|
}
|
||||||
#endif // !defined(__CUDACC_RTC__)
|
#endif // !defined(__CUDACC_RTC__)
|
||||||
|
|
||||||
} // end namespace cute
|
} // end namespace cute
|
||||||
|
|||||||
@ -41,7 +41,8 @@
|
|||||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||||
|
|
||||||
#ifndef CUTLASS_GDC_ENABLED
|
#ifndef CUTLASS_GDC_ENABLED
|
||||||
#if (defined(CUTLASS_ENABLE_GDC_FOR_SM90) && \
|
#if (CUDA_BARRIER_ENABLED && \
|
||||||
|
defined(CUTLASS_ENABLE_GDC_FOR_SM90) && \
|
||||||
__CUDACC_VER_MAJOR__ >= 12 && \
|
__CUDACC_VER_MAJOR__ >= 12 && \
|
||||||
defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL))
|
defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL))
|
||||||
#define CUTLASS_GDC_ENABLED
|
#define CUTLASS_GDC_ENABLED
|
||||||
|
|||||||
@ -43,7 +43,6 @@
|
|||||||
#include "cute/arch/cluster_sm90.hpp"
|
#include "cute/arch/cluster_sm90.hpp"
|
||||||
#include "cute/atom/mma_atom.hpp"
|
#include "cute/atom/mma_atom.hpp"
|
||||||
#include "cute/algorithm/gemm.hpp"
|
#include "cute/algorithm/gemm.hpp"
|
||||||
#include "cute/tensor_predicate.hpp"
|
|
||||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||||
#include "cutlass/trace.h"
|
#include "cutlass/trace.h"
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,6 @@

 #include "cutlass/cutlass.h"

-#include "cute/tensor_predicate.hpp"
 #include "cute/arch/cluster_sm90.hpp"
 #include "cute/arch/copy_sm90.hpp"
 #include "cute/atom/mma_atom.hpp"
@@ -103,7 +102,7 @@ struct CollectiveConv<

 using PipelineParams = typename MainloopPipeline::Params;
 using PipelineState = typename cutlass::PipelineState<DispatchPolicy::Stages>;

 using ProblemShape = ConvProblemShape<ConvOp, NumSpatialDimensions>;

 static_assert(rank(SmemLayoutA{}) == 3, "SmemLayout must be rank 3 (M/N, K, PIPE)");
@@ -332,7 +331,7 @@
 TmaTransactionBytes
 };
 }

 template <class ProblemShape>
 static bool
 can_implement(
@@ -409,7 +408,7 @@
 if constexpr (ConvOp == conv::Operator::kWgrad) {
 #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
 std::ostringstream os;
 #endif
 const auto & input_shape = problem_shape.shape_A;
 const auto & input_stride = problem_shape.stride_A;

@@ -431,11 +430,11 @@
 << "\n input_shape: " << input_shape
 << "\n input_stride: " << input_stride
 << "\n";
 #endif
 CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed input strides.\n");
 #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
 CUTLASS_TRACE_HOST(os.str());
 #endif
 return false;
 }

@@ -464,7 +463,7 @@
 CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed output strides.\n");
 #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1)
 CUTLASS_TRACE_HOST(os.str());
 #endif
 return false;
 }
 }
@@ -516,8 +515,8 @@
 /// gA_mk - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k)
 /// gB_nk - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k)
 /// The rest of the tensors can be specified as needed by this collective.
 /// The dimensions of gA_mk and gA_nk do not contain L to maintain consistency with
 /// StrideA and StrideB set up for TMA
 template <class ProblemShapeMNKL>
 CUTLASS_DEVICE auto
 load_init(ProblemShapeMNKL const& problem_shape_MNKL, Params const& mainloop_params){
@@ -303,6 +303,16 @@
 dim3 cluster(cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
 cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
 cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{}));
+// Dynamic cluster support
+[[maybe_unused]] dim3 fallback_cluster = dim3{0,0,0};
+if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 100 ||
+ConvKernel::ArchTag::kMinComputeCapability == 101) {
+if constexpr (!cute::is_static_v<typename ConvKernel::DispatchPolicy::ClusterShape>) {
+fallback_cluster = params.hw_info.cluster_shape_fallback;
+cluster = params.hw_info.cluster_shape;
+}
+}
+
 void* kernel_params[] = {&params};
 if constexpr (kEnableCudaHostAdapter) {
 //
@@ -313,6 +323,7 @@

 launch_result = cuda_adapter->launch(grid,
 cluster,
+fallback_cluster,
 block,
 smem_size,
 stream,
@@ -338,6 +349,20 @@
 grid, cluster, block, smem_size, stream, kernel, kernel_params);
 }
 }
+else {
+if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 100 ||
+ConvKernel::ArchTag::kMinComputeCapability == 101) {
+launch_result = ClusterLauncher::launch_with_fallback_cluster(
+grid,
+cluster,
+fallback_cluster,
+block,
+smem_size,
+stream,
+kernel,
+kernel_params);
+}
+}
 }
 }
 else {
@@ -48,7 +48,7 @@ namespace cutlass::detail{
 using namespace cute;

 template<int SFVecSizeM, int SFVecSizeN, int SFVecSizeK, UMMA::Major majorSFA = UMMA::Major::MN, UMMA::Major majorSFB = UMMA::Major::MN>
-struct Sm100BlockwiseScaleConfig {
+struct Sm1xxBlockwiseScaleConfig {

 using ShapeSFA = Shape<Shape<Int<SFVecSizeM>, int32_t>, Shape<Int<SFVecSizeK>, int32_t>, int32_t>;
 using ShapeSFB = Shape<Shape<Int<SFVecSizeN>, int32_t>, Shape<Int<SFVecSizeK>, int32_t>, int32_t>;
@@ -271,7 +271,18 @@ struct RuntimeBlockwiseScaleConfig {

 // Sm90 only supports MN major for SFA and SFB for now
 template<int SFVecSizeM, int SFVecSizeN, int SFVecSizeK>
-using Sm90BlockwiseScaleConfig = Sm100BlockwiseScaleConfig<SFVecSizeM, SFVecSizeN, SFVecSizeK>;
+using Sm90BlockwiseScaleConfig = Sm1xxBlockwiseScaleConfig<SFVecSizeM, SFVecSizeN, SFVecSizeK>;

+template<int SFVecSizeM, int SFVecSizeN, int SFVecSizeK, UMMA::Major majorSFA = UMMA::Major::MN, UMMA::Major majorSFB = UMMA::Major::MN>
+using Sm100BlockwiseScaleConfig = Sm1xxBlockwiseScaleConfig<SFVecSizeM, SFVecSizeN, SFVecSizeK, majorSFA, majorSFB>;
+
+template<int SFVecSizeM, int SFVecSizeN, int SFVecSizeK, UMMA::Major majorSFA = UMMA::Major::MN, UMMA::Major majorSFB = UMMA::Major::MN>
+using Sm120BlockwiseScaleConfig = Sm1xxBlockwiseScaleConfig<SFVecSizeM, SFVecSizeN, SFVecSizeK, majorSFA, majorSFB>;
+
+template<class MmaTileShape_MNK>
+constexpr auto sm90_trivial_blockwise_scale_config(MmaTileShape_MNK) {
+return Sm90BlockwiseScaleConfig<size<0>(MmaTileShape_MNK{}), size<1>(MmaTileShape_MNK{}), size<2>(MmaTileShape_MNK{})>{};
+}
+
 template<class MmaTileShape_MNK>
 constexpr auto sm100_trivial_blockwise_scale_config(MmaTileShape_MNK) {
@@ -279,8 +290,8 @@ constexpr auto sm100_trivial_blockwise_scale_config(MmaTileShape_MNK) {
 }

 template<class MmaTileShape_MNK>
-constexpr auto sm90_trivial_blockwise_scale_config(MmaTileShape_MNK) {
-return Sm90BlockwiseScaleConfig<size<0>(MmaTileShape_MNK{}), size<1>(MmaTileShape_MNK{}), size<2>(MmaTileShape_MNK{})>{};
+constexpr auto sm120_trivial_blockwise_scale_config(MmaTileShape_MNK) {
+return Sm120BlockwiseScaleConfig<size<0>(MmaTileShape_MNK{}), size<1>(MmaTileShape_MNK{}), size<2>(MmaTileShape_MNK{})>{};
 }

 /////////////////////////////////////////////////////////////////////////////////////////////////
@@ -371,11 +371,14 @@ template <
 constexpr int
 get_input_alignment_bits() {
 if constexpr (IsF8F6F4SubBytes && sizeof_bits<ElementType>::value == 4) {
+// 16U4 format: The inner tensor size dimension should be multiple of 64B.
 return 64 * 8;
 }
 else if constexpr (IsF8F6F4SubBytes && sizeof_bits<ElementType>::value == 6) {
+// 16U6 format : The inner tensor size dimension must be a multiple of 96B.
 return 96 * 8;
 }
+// TMA 16B alignment requirement
 return 128;
 }

@@ -383,12 +386,11 @@ get_input_alignment_bits() {
 template <class ElementType>
 constexpr int
 get_output_alignment_bits() {
-
 if constexpr (sizeof_bits<ElementType>::value == 6) {
-// U6 format : The inner tensor size dimension must be a multiple of 96B.
+// 16U6 format : The inner tensor size dimension must be a multiple of 96B.
 return 96 * 8;
 }
+// TMA 16B alignment requirement
 return 128;
 }

@@ -981,6 +981,9 @@ private:
 static constexpr bool Is2SmMma = is_base_of_v<TmaWarpSpecialized2Sm, Schedule>;
 static_assert(Is1SmMma ^ Is2SmMma, "unsupported schedule");
 static_assert(not (Is2SmMma && size<0>(ClusterShape_MNK{}) % 2 == 1), "schedule + cluster mismatch");
+// C/D should meet TMA alignment requirement if not void
+static_assert(detail::is_aligned<ElementC_, AlignmentC, ElementD_, AlignmentD>(),
+"C/D Should meet TMA alignment requirement\n");

 static constexpr bool DisableDestination = cute::is_void_v<ElementD_>;
 using ElementD = cute::conditional_t<DisableDestination,fusion::get_element_aux_t<FusionOpOrCallbacks>,ElementD_>; // prevents void ref breakages
@@ -293,6 +293,9 @@ template <
 class DispatchPolicy
 >
 struct Sm90TmaBuilderImpl {
+// C/D should meet TMA alignment requirement if not void
+static_assert(detail::is_aligned<ElementC_, AlignmentC, ElementD_, AlignmentD>(),
+"C/D Should meet TMA alignment requirement\n");
 // Passing void D disables destination store + smem allocation
 using ElementD = cute::conditional_t<cute::is_void_v<ElementD_>,
 fusion::get_element_aux_t<FusionOpOrCallbacks>, ElementD_>;
@@ -91,6 +91,14 @@ sm90_get_smem_load_op_for_source() {
 }
 }

+// C/D should meet TMA alignment requirement if not void
+template <class ElementC, int AlignmentC, class ElementD, int AlignmentD>
+constexpr bool
+is_aligned() {
+return (cute::is_void_v<ElementC> || (cute::sizeof_bits_v<ElementC> * AlignmentC) % cutlass::detail::get_output_alignment_bits<ElementC>() == 0) &&
+(cute::is_void_v<ElementD> || (cute::sizeof_bits_v<ElementD> * AlignmentD) % cutlass::detail::get_output_alignment_bits<ElementD>() == 0);
+}
+
 ///////////////////////////////////////////////////////////////////////////////

 } // namespace cutlass::epilogue::collective::detail
@@ -217,6 +217,16 @@ struct IsThreadEpilogueOpWithActivation <ThreadEpilogueOp, cute::enable_if_t<Thr
 using type = typename ThreadEpilogueOp::ActivationFn;
 };

+template <typename ThreadEpilogueOp, typename = void>
+struct IsThreadEpilogueOpWithPerChannelScaled {
+static constexpr bool value = false;
+};
+
+template <typename ThreadEpilogueOp>
+struct IsThreadEpilogueOpWithPerChannelScaled <ThreadEpilogueOp, cute::void_t<decltype(ThreadEpilogueOp::IsPerRowScaleSupported)>> {
+static constexpr bool value = ThreadEpilogueOp::IsPerRowScaleSupported || ThreadEpilogueOp::IsPerColScaleSupported;
+};
+
 template <typename ThreadEpilogueOp, typename = void>
 struct IsThreadEpilogueOpWithElementwiseArguments : cute::false_type {};

@@ -57,7 +57,7 @@ struct IsDefaultFusionOp {
 };

 template<
 class ElementD, class ElementCompute,
 class ElementC, FloatRoundStyle RoundStyle
 >
 struct IsDefaultFusionOp<
@@ -69,7 +69,7 @@ struct IsDefaultFusionOp<

 template<
 class ElementOutput, int Count, class ElementAccumulator,
 class ElementCompute, epilogue::thread::ScaleType::Kind Scale,
 FloatRoundStyle Round, class ElementSource
 >
 struct IsDefaultFusionOp<
@@ -133,7 +133,7 @@ public:
 constexpr static int ThreadCount = 128;
 constexpr static int kOutputAlignment = ThreadEpilogueOp::kCount;
 constexpr static bool isEpilogueBiasSupported = detail::IsThreadEpilogueOpWithBias<ThreadEpilogueOp>::value;

 using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
 constexpr static uint32_t TmaTransactionBytes = 0;

@@ -240,7 +240,7 @@ public:
 Tensor tTR_rAcc = make_tensor<ElementAccumulator>(shape(tTR_gD)); // (T2R,T2R_M,T2R_N)

 Tensor tTR_rC = make_tensor<GmemElementC>(shape(tTR_gC)); // (T2R,T2R_M,T2R_N)

 Tensor coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l)
 Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l)
 Tensor tTR_cCD = thread_t2r.partition_D(cCD); // (T2R,T2R_M,T2R_N) -> (m,n,l)
@@ -250,7 +250,7 @@ public:
 Tensor tTR_rD_frag = make_tensor<ElementD>(shape(tTR_rAcc));
 Tensor tTR_rD_src = recast<Array<ElementD, VD>>(coalesce(tTR_rD_frag));
 Tensor tR2G_rD_dst = recast<Array<ElementD, VD>>(coalesce(tTR_gD));

 Tensor tTR_cD_mn_frg = tensor<1>(zipped_divide(coalesce(tTR_cCD), mclD.compose(Int<VD>{})));
 Tensor tDpD = make_tensor<bool>(shape(tR2G_rD_dst));

@@ -325,7 +325,7 @@ public:
 copy_if(tDpD, tTR_rD_src, tR2G_rD_dst);
 }
 // source is not needed, avoid load
 else
 {
 CUTLASS_PRAGMA_UNROLL
 for (int i = 0; i < size(tTR_rAcc); i++) {
@@ -382,7 +382,7 @@ public:
 auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r));
 Tensor tTR_gC = thread_t2r.partition_D(gC); // (T2R,T2R_M,T2R_N)
 Tensor tTR_gD = thread_t2r.partition_D(gD); // (T2R,T2R_M,T2R_N)


 Tensor coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l)
 Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l)
@@ -498,7 +498,7 @@ public:
 // Constructor and Data Members
 //
 CUTLASS_DEVICE
 CollectiveEpilogue(Params const& params_, SharedStorage& shared_tensors)
 : fusion_callbacks(params_.thread, shared_tensors.thread)
 , smem_buffer_ptr(shared_tensors.buffer.data())
 , params(params_) {};
@@ -506,7 +506,7 @@ public:
 protected:
 FusionCallbacks fusion_callbacks;
 uint8_t* smem_buffer_ptr;
 Params const& params;

 public:

@@ -543,7 +543,7 @@ public:
 can_implement(
 [[maybe_unused]] ProblemShape const& problem_shape,
 [[maybe_unused]] Arguments const& args) {

 bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread);
 if (!fusion_implementable) {
 CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n");
@@ -636,7 +636,7 @@ public:
 Tensor tTR_rD_frg = recast<Array<ElementD, FragmentSize>>(coalesce(tTR_rD));

 auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{
 problem_shape_mnkl,
 cta_tile_mnk,
 cta_coord_mnkl,
 int(0),
@@ -693,20 +693,17 @@ public:
 }

 Tensor tTR_cCD_mn = tTR_cCD(_,_,_,epi_m,epi_n);
+Tensor tTR_pCD_mn = cute::lazy::transform(tTR_cCD_mn, [&] (auto const& c) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(c, problem_shape_mnl); });
 cst_callbacks.begin_loop(epi_m, epi_n);

 if constexpr (not cute::is_void_v<ElementC>) {
 if (is_C_load_needed) {
 using CVecType = uint_bit_t<VC * sizeof_bits_v<ElementC>>;
-Tensor tTR_cC_frag = tensor<1>(zipped_divide(coalesce(tTR_cCD_mn), mclC.compose(Int<VC>{})));
-
-auto pred_fn_C = [&] (auto const&... coords) CUTLASS_LAMBDA_FUNC_INLINE {
-return elem_less(tTR_cC_frag(coords...), problem_shape_mnl);
-};

 Tensor tTR_gC_frg = recast<CVecType>(coalesce(tTR_gC(_,_,_,epi_m,epi_n)));
 Tensor tTR_rC_frg = recast<CVecType>(coalesce(tCrC));
-copy_if(pred_fn_C, tTR_gC_frg, tTR_rC_frg);
+Tensor tTR_pC_frg = tensor<1>(zipped_divide(coalesce(tTR_pCD_mn), mclC.compose(Int<VC>{})));
+copy_if(tTR_pC_frg, tTR_gC_frg, tTR_rC_frg);
 }
 }

@@ -717,7 +714,7 @@ public:
 Tensor tTR_rAcc_frg = recast<Array<ElementAccumulator, FragmentSize>>(coalesce(tTR_rAcc));

 copy(tiled_t2r, tTR_tAcc_mn, tTR_rAcc);

 // After the last tmem load, signal that tmem buffer is consumed and empty
 if (do_acc_release) {
 cutlass::arch::fence_view_async_tmem_load();
@@ -737,16 +734,11 @@ public:

 cst_callbacks.end_loop(epi_m, epi_n);

-
-Tensor tTR_cD_frag = tensor<1>(zipped_divide(coalesce(tTR_cCD_mn), mclD.compose(Int<VD>{})));
-auto pred_fn_D = [&] (auto const&... coords) CUTLASS_LAMBDA_FUNC_INLINE {
-return elem_less(tTR_cD_frag(coords...), problem_shape_mnl);
-};
-
 using VecType = uint_bit_t<VD * sizeof_bits_v<ElementD>>;
 Tensor tTR_gD_frg = recast<VecType>(coalesce(tTR_gD(_,_,_,epi_m,epi_n)));
 Tensor tTR_rD_frg = recast<VecType>(coalesce(tTR_rD));
-copy_if(pred_fn_D, tTR_rD_frg, tTR_gD_frg);
+Tensor tTR_pD_frg = tensor<1>(zipped_divide(coalesce(tTR_pCD_mn), mclD.compose(Int<VD>{})));
+copy_if(tTR_pD_frg, tTR_rD_frg, tTR_gD_frg);
 } // for epi_m
 } // for epi_n

@@ -59,7 +59,7 @@ template <
 >
 class Epilogue {
 static_assert(cute::is_same_v<EpilogueScheduleType, EpilogueSimtVectorized> ||
 cute::is_same_v<EpilogueScheduleType, EpiloguePtrArraySimtVectorized>,
 "Could not find an epilogue specialization.");
 };

@@ -141,7 +141,7 @@ public:
 ElementScalar const* beta_ptr = nullptr;
 ElementBias const* bias_ptr = nullptr;
 StrideBias dBias{};
 };

 template<class ThreadEpiOp>
 struct ThreadEpilogueOpArguments<
@@ -202,7 +202,7 @@ public:
 to_underlying_arguments(
 [[maybe_unused]] ProblemShape const& _,
 Arguments const& args,
 [[maybe_unused]] void* workspace) {
 typename ThreadEpilogueOp::Params thread_op_args;
 thread_op_args.alpha = args.thread.alpha;
 thread_op_args.beta = args.thread.beta;
@@ -317,7 +317,7 @@ public:
 Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)
 Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)
 Tensor gBias = gBias_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)

 // Construct a tensor in SMEM that we can partition for rearranging data
 SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);
 Tensor sAcc = make_tensor(make_smem_ptr(storage.smem_epilogue.data()), SmemLayout{}); // (SMEM_M,SMEM_N)
@@ -389,10 +389,10 @@ public:
 Tensor tSR_gBias_flt = filter_zeros(tSR_gBias);
 Tensor tSR_rBias_flt = filter_zeros(tSR_rBias);
 Tensor tSR_cD_flt = filter_zeros(tSR_cD, tSR_gBias.stride());
+Tensor tSR_pD_flt = cute::lazy::transform(tSR_cD_flt, [&](auto const& c){ return elem_less(c, take<0,2>(residue_mnk)); });

 // Step 0. Copy Bias from GMEM to fragment
-auto pred_fn = [&] (auto const&... coords) { return elem_less(tSR_cD_flt(coords...), take<0, 2>(residue_mnk)); };
-copy_if(pred_fn, tSR_gBias_flt, tSR_rBias_flt);
+copy_if(tSR_pD_flt, tSR_gBias_flt, tSR_rBias_flt);
 }
 }

@@ -560,18 +560,18 @@ struct Sm90TreeVisitor<
 Tensor tC_rAux_vec = recast<VecType>(tC_rAux);
 Tensor tC_gAux_vec = recast<VecType>(tC_gAux);
 Tensor tC_cAux_vec = tensor<1>(zipped_divide(tC_cAux, MCL.compose(Int<V>{})));
-auto predicate_fn = [&] (auto&&... coords) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(tC_cAux_vec(coords...), residue_tC_cAux); };
-copy_if(predicate_fn, tC_rAux_vec, tC_gAux_vec);
+Tensor tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, residue_tC_cAux); });
+copy_if(tC_pAux_vec, tC_rAux_vec, tC_gAux_vec);
 }
 // sub-byte vectorization, must serialize threads
 else {
 // Assumes no inter-warp sharing of bytes (most copy layouts should satisfy this)
 int lane_idx = canonical_lane_idx();
-auto predicate_fn = [&] (auto&&... coords) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(tC_cAux(coords...), residue_tC_cAux); };
+Tensor tC_pAux = cute::lazy::transform(tC_cAux, [&](auto const& c){ return elem_less(c, residue_tC_cAux); });
 CUTLASS_PRAGMA_NO_UNROLL
 for (int i = 0; i < NumThreadsPerWarp; ++i) {
 if (lane_idx == i) {
-copy_if(predicate_fn, tC_rAux, tC_gAux);
+copy_if(tC_pAux, tC_rAux, tC_gAux);
 }
 __syncwarp();
 }
@@ -719,12 +719,12 @@ struct Sm90AuxLoad<
 Tensor tC_gAux_vec = recast<VecType>(tC_gAux);
 Tensor tC_rAux_vec = recast<VecType>(tC_rAux);
 Tensor tC_cAux_vec = tensor<1>(zipped_divide(tC_cAux, MCL.compose(Int<V>{})));
-auto predicate_fn = [&] (auto&&... coords) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(tC_cAux_vec(coords...), residue_tC_cAux); };
-copy_if(predicate_fn, tC_gAux_vec, tC_rAux_vec);
+Tensor tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, residue_tC_cAux); });
+copy_if(tC_pAux_vec, tC_gAux_vec, tC_rAux_vec);
 }
 else {
-auto predicate_fn = [&] (auto&&... coords) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(tC_cAux(coords...), residue_tC_cAux); };
-copy_if(predicate_fn, tC_gAux, tC_rAux);
+Tensor tC_pAux = cute::lazy::transform(tC_cAux, [&](auto const& c){ return elem_less(c, residue_tC_cAux); });
+copy_if(tC_pAux, tC_gAux, tC_rAux);
 }
 }
 }
@@ -738,8 +738,8 @@ struct Sm90AuxLoad<
 }
 }

-auto predicate_fn = [&] (auto&&... coords) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(tC_cAux(_,_,_,epi_m,epi_n)(coords...), residue_tC_cAux); };
-copy_if(predicate_fn, tC_gAux(_,_,_,epi_m,epi_n), tC_rAux);
+Tensor tC_pAux = cute::lazy::transform(tC_cAux(_,_,_,epi_m,epi_n), [&](auto const& c){ return elem_less(c, residue_tC_cAux); });
+copy_if(tC_pAux, tC_gAux(_,_,_,epi_m,epi_n), tC_rAux);
 }
 }

@@ -449,7 +449,7 @@ template <
 bool EnableNullptr
 >
 struct Sm90AuxLoad<
 0, EpilogueTile, Element, LayoutOrStrideMNL,
 SmemLayoutAtom, CopyOpS2R, Alignment, EnableNullptr
 > {
 using ElementAux = Element;
@@ -496,7 +496,7 @@ struct Sm90AuxLoad<
 CUTLASS_HOST_DEVICE
 Sm90AuxLoad(Params const& params, SharedStorage const& shared_storage)
 : params_ptr(&params) { }

 Params const* params_ptr;

 CUTLASS_DEVICE bool
@@ -533,7 +533,7 @@ struct Sm90AuxLoad<
 tC_cAux(cute::forward<CTensorG2R>(tC_cAux)),
 problem_shape_mnl(problem_shape_mnl),
 params_ptr(params_ptr) {}

 GTensorG2R tC_gAux;
 RTensor tC_rAux;
 CTensorG2R tC_cAux;
@@ -551,17 +551,13 @@ struct Sm90AuxLoad<
 constexpr auto MCL = decltype(max_common_layout(tC_gAux(_,_,_,_0{},_0{}), tC_rAux)){};
 constexpr int V = cute::min(Alignment, size(MCL));

-Tensor tC_cAux_mn = tC_cAux(_,_,_,epi_m,epi_n);
-Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux_mn), MCL.compose(Int<V>{})));
-
 Tensor tC_gAux_vec = recast<Array<Element, V>>(coalesce(tC_gAux(_,_,_,epi_m,epi_n)));
 Tensor tC_rAux_vec = recast<Array<Element, V>>(coalesce(tC_rAux));

-auto pred_fn = [&] (auto const&... coords) CUTLASS_LAMBDA_FUNC_INLINE {
-return elem_less(tC_cAux_vec(coords...), problem_shape_mnl);
-};
-
-copy_if(pred_fn, tC_gAux_vec, tC_rAux_vec);
+Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux(_,_,_,epi_m,epi_n)), MCL.compose(Int<V>{})));
+Tensor tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, problem_shape_mnl); });
+
+copy_if(tC_pAux_vec, tC_gAux_vec, tC_rAux_vec);
 }

 template <typename ElementAccumulator, int FragmentSize>
@@ -647,7 +643,7 @@ struct Sm90ScalarBroadcast {
 can_implement(ProblemShape const& problem_shape, Arguments const& args) {
 return true;
 }

 template <class ProblemShape>
 static size_t
 get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
@@ -674,11 +670,11 @@ struct Sm90ScalarBroadcast {
 // This must be called after update_scalar is called
 CUTLASS_DEVICE bool
 is_zero() const {
 if (get<2>(params_ptr->dScalar[0]) == 0) {
 // Only 1 batch
 return scalar == Element(0);
 }
 else {
 // multiple batch
 if (valid_scalar == false) {
 // for stridedBatch kernel, if ptr has a valid address, we need to enable the epi_load warps.
@@ -761,7 +757,7 @@ private:

 if (params_ptr->scalar_ptrs[0] != nullptr) {
 scalar = params_ptr->scalar_ptrs[0][l_offset];
 }
 else {
 // batch stride is ignored for nullptr fallback
 scalar = params_ptr->scalars[0];
@@ -774,7 +770,7 @@ private:
 if (params_ptr->scalar_ptrs[i] != nullptr) {
 int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]);
 scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]);
 }
 else {
 // batch stride is ignored for nullptr fallback
 scalar = reduction_fn(scalar, params_ptr->scalars[i]);
@@ -826,7 +822,7 @@ struct Sm90ScalarBroadcastPtrArray {
 can_implement(ProblemShape const& problem_shape, Arguments const& args) {
 return true;
 }

 template <class ProblemShape>
 static size_t
 get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
@@ -946,7 +942,7 @@ private:
 if (params_ptr->scalar_ptrs[i] != nullptr) {
 int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]);
 scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]);
 }
 else {
 // batch stride is ignored for nullptr fallback
 scalar = reduction_fn(scalar, params_ptr->scalars[i]);
@@ -992,7 +988,7 @@ struct Sm90RowBroadcast {
 static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))> || IsDynamicBroadcast); // batch stride can be dynamic or static
 static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{} || IsDynamicBroadcast);

 struct SharedStorage {
 array_aligned<ElementInput, size<1>(CtaTileShapeMNK{})> smem;
 };

@@ -1078,8 +1074,8 @@ struct Sm90RowBroadcast {
 struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
 CUTLASS_DEVICE
 ConsumerStoreCallbacks(
 GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
 GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
 SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
 Residue residue_cRow_, Params const& params_)
 : tGS_gRow(tGS_gRow_)
@@ -1098,8 +1094,8 @@ struct Sm90RowBroadcast {
 Tiled_G2S tiled_G2S;

 SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
 SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)

 Residue residue_cRow; // (m, n)
 Params const& params;

@@ -1113,7 +1109,7 @@ struct Sm90RowBroadcast {

 for (int i = 0; i < size(tGS_gRow_flt); ++i) {
 if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
 continue; // OOB of SMEM,
 }
 if (not is_nullptr && elem_less(tGS_cRow_flt(i), residue_cRow)) {
 tGS_sRow_flt(i) = tGS_gRow_flt(i); // issue async gmem to smem load
@@ -1201,18 +1197,18 @@ struct Sm90RowBroadcast {
 }
 Tensor mRow = make_tensor(make_gmem_ptr(ptr_row), make_layout(layout_M,layout_N,layout_L));
 Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
 Tensor sRow = make_tensor(make_smem_ptr(smem),
 make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
 //// G2S: Gmem to Smem
 auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, ElementInput>{},
 Layout< Shape<_1, ThreadCount>,
 Stride<_0, _1>>{},
 Layout<_1>{});
 auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
 Tensor tGS_gRow = thr_g2s.partition_S(gRow);
 Tensor tGS_sRow = thr_g2s.partition_D(sRow);

 //// G2S: Coord
 Tensor tGS_cRow = thr_g2s.partition_S(args.cD);

 //// S2R: Smem to Reg
@@ -1220,11 +1216,11 @@ struct Sm90RowBroadcast {
 Tensor tSR_rRow = make_tensor_like<ElementCompute>(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)

 return ConsumerStoreCallbacks(
 tGS_gRow,
 tGS_sRow,
 tGS_cRow, tiled_g2s,
 tSR_sRow,
 tSR_rRow,
 args.residue_cD,
 params);
 }
@@ -1378,12 +1374,12 @@ struct Sm90ColBroadcast {
 Tensor tCgCol_vec = recast<VecType>(coalesce(tCgCol_flt));
 Tensor tCrCol_vec = recast<VecType>(coalesce(tCrCol_flt));
 Tensor tCcCol_vec = tensor<1>(zipped_divide(tCcCol_flt, MCL.compose(Int<V>{})));
-auto pred_fn = [&] (auto const&... coords) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(tCcCol_vec(coords...), residue_tCcCol); };
-copy_if(pred_fn, tCgCol_vec, tCrCol_vec);
+Tensor tCpCol_vec = cute::lazy::transform(tCcCol_vec, [&](auto const& c){ return elem_less(c, residue_tCcCol); });
+copy_if(tCpCol_vec, tCgCol_vec, tCrCol_vec);
 }
 else {
-auto pred_fn = [&] (auto const&... coords) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(tCcCol_flt(coords...), residue_tCcCol); };
-copy_if(pred_fn, tCgCol_flt, tCrCol_flt);
+Tensor tCpCol_flt = cute::lazy::transform(tCcCol_flt, [&](auto const& c){ return elem_less(c, residue_tCcCol); });
+copy_if(tCpCol_flt, tCgCol_flt, tCrCol_flt);
 }

 constexpr int FrgSize = size(tCrCol_flt);
@@ -412,17 +412,13 @@ struct Sm90AuxStore<
 constexpr auto MCL = decltype(max_common_layout(tC_gAux(_,_,_,_0{},_0{}), tC_rAux)){};
 constexpr int V = cute::min(Alignment, size(MCL));

-Tensor tC_cAux_mn = tC_cAux(_,_,_,epi_m,epi_n);
-Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux_mn), MCL.compose(Int<V>{})));
-
 Tensor tC_gAux_vec = recast<Array<Element, V>>(coalesce(tC_gAux(_,_,_,epi_m,epi_n)));
 Tensor tC_rAux_vec = recast<Array<Element, V>>(coalesce(tC_rAux));

-auto pred_fn = [&] (auto const&... coords) {
-return elem_less(tC_cAux_vec(coords...), problem_shape_mnl);
-};
-
-copy_if(pred_fn, tC_rAux_vec, tC_gAux_vec);
+Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux(_,_,_,epi_m,epi_n)), MCL.compose(Int<V>{})));
+Tensor tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, problem_shape_mnl); });
+
+copy_if(tC_pAux_vec, tC_rAux_vec, tC_gAux_vec);
 }
 };

@@ -540,6 +540,23 @@ struct HardSwish<Array<T, N> > {
 }
 };

+template <int N>
+struct HardSwish<Array<half_t, N> > {
+using T = half_t;
+static const bool kIsHeavy = false;
+static constexpr float kOneSixth = 0.16666667f;
+
+CUTLASS_HOST_DEVICE
+Array<T, N> operator()(Array<T, N> const &value) const {
+minimum<Array<T, N> > mn;
+maximum<Array<T, N> > mx;
+multiplies<Array<T, N> > mul;
+plus<Array<T, N> > add;
+
+return mul(mul(mn(mx(add(value, T(3)), T(0)), T(6)), value), T(kOneSixth));
+}
+};
+
 template <typename T>
 using ScaledHardSwish = Scale<HardSwish<T>>;

@@ -542,12 +542,8 @@ struct VisitorColBroadcast {
 }
 }
 clear(tC_rCol);
-Tensor pred = make_tensor<bool>(shape(tC_gCol));
-CUTLASS_PRAGMA_UNROLL
-for (int i = 0; i < size(pred); ++i) {
-pred(i) = get<0>(tC_cCol(i)) < m;
-}
-copy_if(pred, tC_gCol, tC_rCol);
+Tensor tC_pCol = cute::lazy::transform(tC_cCol, [&] (auto const& c) { return get<0>(c) < m; });
+copy_if(tC_pCol, tC_gCol, tC_rCol);
 }

 template <class ElementAccumulator, int FragmentSize>
@@ -446,7 +446,7 @@ public:

 Status
 construct_graph(bool launch_with_pdl) {
-#if ((__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4))
+#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
 Status status = Status::kSuccess;

 // Destroy existing graph, if created
@@ -47,7 +47,7 @@ void launch_full_barrier(
 cudaStream_t stream,
 bool launch_with_pdl) {

-#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4))
+#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))
 // Legacy (kernel) launch with PDL
 cudaLaunchAttribute attributes[1];
 attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
@@ -268,7 +268,7 @@ struct CollectiveBuilder<
 // Calculate SMEM matrix A and B buffers' pipeline stages and the accumulator stages.
 static constexpr uint32_t AccumulatorNPerCta = cute::size<1>(TileShape_MNK{});
 static constexpr uint32_t AccumulatorPipelineStageCount = (AccumulatorNPerCta == 256) ? 1 : 2;
-static constexpr uint32_t SchedulerPipelineStageCount = 1;
+static constexpr uint32_t SchedulerPipelineStageCount = 2;

 using SmemTileShape = cute::Shape<BlockTileA_M, BlockTileB_N, BlockTileA_K>;

@@ -238,7 +238,7 @@ struct CollectiveBuilder<
 static constexpr bool IsArrayOfPointersGemm = cute::is_base_of_v<KernelSchedulePtrArrayBlockScaledGemmSm100, BuilderScheduleTag>;
 // Grouped GEMM(where Stride type is Stride*) uses specific static tile scheduler.
 static constexpr bool IsGroupGemm = !cute::is_same_v<StrideA, InternalStrideA>;
-static constexpr uint32_t SchedulerPipelineStageCount = cute::conditional_return<IsGroupGemm>(8, 1);
+static constexpr uint32_t SchedulerPipelineStageCount = cute::conditional_return<IsGroupGemm>(8, 2);

 static constexpr uint32_t KernelSmemCarveout = detail::Sm100DenseGemmTmaUmmaCarveout<
 ClusterShape_MNK,
@@ -51,6 +51,8 @@ struct Sm100DenseGemmTmaUmmaCarveout {
 static constexpr auto LoadOrderBarrierStorage = sizeof(typename cutlass::OrderedSequenceBarrier<1,2>::SharedStorage);
 // CLC (scheduler) response
 static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize;
+// CLC Throttle pipeline storage
+static constexpr auto CLCThrottlePipelineStorage = sizeof(typename cutlass::PipelineAsync<SchedulerPipelineStageCount>::SharedStorage);
 // Tmem dealloc
 static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier);
 // Tmem ptr storage
@@ -64,6 +66,7 @@ struct Sm100DenseGemmTmaUmmaCarveout {
 CLCPipelineStorage +
 LoadOrderBarrierStorage +
 TmemDeallocStorage +
+CLCThrottlePipelineStorage +
 CLCResponseStorage +
 TmemBasePtrsStorage +
 TensorMapStorage
@@ -80,6 +83,8 @@ struct Sm100SparseGemmTmaUmmaCarveout {
 static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount, ClusterShape_MNK>::SharedStorage);
 // AccumulatorPipeline = PipelineUmmaAsync
 static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount>::SharedStorage);
+// CLC Throttle pipeline storage
+static constexpr auto CLCThrottlePipelineStorage = sizeof(typename cutlass::PipelineAsync<SchedulerPipelineStageCount>::SharedStorage);
 // Tmem dealloc
 static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier);

@@ -87,6 +92,7 @@ struct Sm100SparseGemmTmaUmmaCarveout {
 cutlass::round_up(LoadOrderBarrierStorage, 16) +
 cutlass::round_up(CLCPipelineStorage, 16) +
 cutlass::round_up(AccumulatorPipelineStorage, 16) +
+cutlass::round_up(CLCThrottlePipelineStorage, 16) +
 cutlass::round_up(TmemDeallocStorage, 16),
 16));

@@ -371,7 +371,7 @@ struct CollectiveBuilder<
 // Calculate SMEM matrix A and B buffers' pipeline stages and the accumulator stages.
 static constexpr uint32_t AccumulatorNPerCta = cute::size<1>(TileShape_MNK{});
 static constexpr uint32_t AccumulatorPipelineStageCount = AccumulatorNPerCta > 224 ? 1 : 2;
-static constexpr uint32_t SchedulerPipelineStageCount = 1;
+static constexpr uint32_t SchedulerPipelineStageCount = 2;

 using SmemTileShape = cute::Shape<BlockTileA_M, BlockTileB_N, BlockTileA_K>;

@@ -267,7 +267,7 @@ struct CollectiveBuilder<
 static constexpr bool IsArrayOfPointersGemm = (cute::is_base_of_v<KernelScheduleSm100PtrArrayDenseGemm, BuilderScheduleTag>);
 // Grouped GEMM(where Stride type is Stride*) uses specific static tile scheduler.
 static constexpr bool IsGroupGemm = !cute::is_same_v<StrideA, InternalStrideA>;
-static constexpr uint32_t SchedulerPipelineStageCount = cute::conditional_return<IsGroupGemm>(8, 1);
+static constexpr uint32_t SchedulerPipelineStageCount = cute::conditional_return<IsGroupGemm>(8, 2);

 static constexpr uint32_t KernelSmemCarveout = detail::Sm100DenseGemmTmaUmmaCarveout<
 ClusterShape_MNK,
@@ -0,0 +1,264 @@
+ /***************************************************************************************************
+ * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+ #pragma once
+
+ #include "cutlass/gemm/collective/builders/sm120_common.inl"
+
+ /////////////////////////////////////////////////////////////////////////////////////////////////
+
+ namespace cutlass::gemm::collective {
+
+ /////////////////////////////////////////////////////////////////////////////////////////////////
+
+ namespace detail {
+
+ // Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
+ template <
+ int CapacityBytes,
+ class ElementA,
+ class ElementB,
+ class ElementScalar,
+ class TileShapeMNK,
+ class ScaleShapeMNK,
+ class MainloopPipelineStorage,
+ int stages
+ >
+ constexpr int
+ sm120_compute_stage_count_or_override_blockwise(StageCount<stages> stage_count) {
+ return stages;
+ }
+
+ // Returns the maximum number of smem tiles that can be used with a given smem capacity.
+ template <
+ int CapacityBytes,
+ class ElementA,
+ class ElementB,
+ class ElementScalar,
+ class TileShapeMNK,
+ class ScaleShapeMNK,
+ class MainloopPipelineStorage,
+ int carveout_bytes
+ >
+ constexpr auto
+ sm120_compute_stage_count_or_override_blockwise(StageCountAutoCarveout<carveout_bytes> stage_count) {
+ // For F6/F4 sub-bytes, ElementA/B will be passed in as uint8_t
+
+ constexpr auto a_bits = cute::sizeof_bits_v<ElementA>;
+ constexpr auto b_bits = cute::sizeof_bits_v<ElementB>;
+ constexpr auto scale_bits = cute::sizeof_bits_v<ElementScalar>;
+ constexpr auto mainloop_pipeline_bytes = sizeof(MainloopPipelineStorage);
+
+ constexpr int stage_bytes =
+ cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
+ cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
+ cutlass::bits_to_bytes(scale_bits * size<0>(ScaleShapeMNK{}) * size<2>(ScaleShapeMNK{})) +
+ cutlass::bits_to_bytes(scale_bits * size<1>(ScaleShapeMNK{}) * size<2>(ScaleShapeMNK{})) +
+ static_cast<int>(mainloop_pipeline_bytes);
+
+
+ return (CapacityBytes - carveout_bytes) / stage_bytes;
+ }
+
+ } // namespace detail
+
+ /////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+ template <
+ class ElementA,
+ class GmemLayoutATagPair,
+ int AlignmentA,
+ class ElementB,
+ class GmemLayoutBTagPair,
+ int AlignmentB,
+ class ElementAccumulator,
+ class TileShape_MNK,
+ class ClusterShape_MNK,
+ class StageCountType,
+ class BuilderScheduleTag
+ >
+ struct CollectiveBuilder<
+ arch::Sm120,
+ arch::OpClassTensorOp,
+ ElementA,
+ GmemLayoutATagPair,
+ AlignmentA,
+ ElementB,
+ GmemLayoutBTagPair,
+ AlignmentB,
+ ElementAccumulator,
+ TileShape_MNK,
+ ClusterShape_MNK,
+ StageCountType,
+ BuilderScheduleTag,
+ cute::enable_if_t<
+ not cute::is_tuple_v<ElementA> && not cute::is_tuple_v<ElementB> &&
+ not cute::is_complex_v<ElementA> && not cute::is_complex_v<ElementB> &&
+ cute::is_tuple_v<GmemLayoutATagPair> && cute::is_tuple_v<GmemLayoutBTagPair> &&
+ (cute::is_base_of_v<KernelScheduleSm120Blockwise, BuilderScheduleTag> ||
+ cute::is_same_v<KernelScheduleAuto, BuilderScheduleTag>) &&
+ detail::sm1xx_gemm_is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, BuilderScheduleTag>()>>
+ {
+
+ static_assert(detail::is_sm10x_f8f6f4_element<ElementA>() && detail::is_sm10x_f8f6f4_element<ElementB>(),
+ "SM120 TmaWarpSpecialized blockwise scaling builder currently only supports F8F6F4 MMA.");
+ static_assert(cute::is_static_v<TileShape_MNK>, "TileShape has to be static");
+ static_assert(cute::is_static_v<ClusterShape_MNK>, "Cluster has to be static");
+
+ using GmemLayoutATag = cute::remove_cvref_t<decltype(get<0>(GmemLayoutATagPair{}))>;
+ using GmemLayoutSFATag = cute::remove_cvref_t<decltype(get<1>(GmemLayoutATagPair{}))>;
+ using GmemLayoutBTag = cute::remove_cvref_t<decltype(get<0>(GmemLayoutBTagPair{}))>;
+ using GmemLayoutSFBTag = cute::remove_cvref_t<decltype(get<1>(GmemLayoutBTagPair{}))>;
+
+ static_assert(cute::depth(cute::remove_pointer_t<GmemLayoutSFATag>{}) == 2 and
+ cute::depth(cute::remove_pointer_t<GmemLayoutSFBTag>{}) == 2,
+ "Expect SFA and SFB layout to be depth of two with shape ((SFVecMN, restMN),(SFVecK, restK), L)");
+ static_assert(size<1, 0>(cute::remove_pointer_t<GmemLayoutSFATag>{}) ==
+ size<1, 0>(cute::remove_pointer_t<GmemLayoutSFBTag>{}),
+ "SFA and SFB must have equivalent SF vector sizes along K");
+
+ static constexpr cute::UMMA::Major UmmaMajorA = detail::tag_to_umma_major_A<GmemLayoutATag>();
+ static constexpr cute::UMMA::Major UmmaMajorB = detail::tag_to_umma_major_B<GmemLayoutBTag>();
+ static_assert((UmmaMajorA == UMMA::Major::K && UmmaMajorB == UMMA::Major::K), "Only TN layout is supported.");
+
+ using PermTileM = decltype(cute::min(size<0>(TileShape_MNK{}), _128{}));
+ using PermTileN = decltype(cute::min(size<1>(TileShape_MNK{}), _32{}));
+
+ static constexpr bool IsCooperative = !cute::is_base_of_v<KernelTmaWarpSpecializedPingpong, BuilderScheduleTag>;
+ using AtomLayoutMNK = cute::conditional_t<IsCooperative,
+ Layout<Shape<_4,_2,_1>>, Layout<Shape<_2,_2,_1>>>;
+
+ // Data type used by MMA instruction
+ using ElementAMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element<ElementA>());
+ using ElementBMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element<ElementB>());
+
+ static_assert(detail::sm1xx_gemm_check_for_f8f6f4_mix8bit_requirement<ElementAMma, ElementBMma,
+ TileShape_MNK, ClusterShape_MNK,
+ GmemLayoutATag, GmemLayoutBTag, false /*IsSparse*/>(),
+ "TileSize and MNK Major does not met with MMA Mix 8-bit TMA load requirement" );
+
+ // Setup TiledMma
+ using TiledMma = decltype(cute::make_tiled_mma(
+ cute::rr_op_selector_sm120<ElementA, ElementB, ElementAccumulator>(),
+ AtomLayoutMNK{},
+ Tile<PermTileM, PermTileN, _32>{}
+ ));
+
+ // DType check
+ static constexpr bool UseF8f6f4 = detail::is_sm120_f8f6f4<TiledMma, ElementA, ElementB>();
+ static_assert(UseF8f6f4, "Non-blockscaled collective builder only supports F8F6F4 MMA.\n");
+
+ // Element type
+ using SmemAllocTypeA = cute::conditional_t<UseF8f6f4, uint8_t, typename TiledMma::ValTypeA>;
+ using SmemAllocTypeB = cute::conditional_t<UseF8f6f4, uint8_t, typename TiledMma::ValTypeB>;
+
+ using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
+ using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
+
+ using SmemLayoutAtomA = decltype(detail::sm120_rr_smem_selector<SmemAllocTypeA, decltype(size<2>(TileShape_MNK{}))>());
+ using SmemLayoutAtomB = decltype(detail::sm120_rr_smem_selector<SmemAllocTypeB, decltype(size<2>(TileShape_MNK{}))>());
+
+ using StrideA = cutlass::gemm::TagToStrideA_t<GmemLayoutATag>;
+ using StrideB = cutlass::gemm::TagToStrideB_t<GmemLayoutBTag>;
+ using StrideSFA = cutlass::gemm::TagToStrideA_t<GmemLayoutSFATag>;
+ using StrideSFB = cutlass::gemm::TagToStrideB_t<GmemLayoutSFBTag>;
+
+ static constexpr int ScaleGranularityM = size<0,0>(cute::remove_pointer_t<GmemLayoutSFATag>{});
+ static constexpr int ScaleGranularityN = size<0,0>(cute::remove_pointer_t<GmemLayoutSFBTag>{});
+ static constexpr int ScaleGranularityK = size<1,0>(cute::remove_pointer_t<GmemLayoutSFBTag>{});
+
+ static_assert(size<0>(TileShape_MNK{}) % ScaleGranularityM == 0, "Scale Granularity M must evenly divide the tile shape M.");
+ static_assert(size<1>(TileShape_MNK{}) % ScaleGranularityN == 0, "Scale Granularity N must evenly divide the tile shape N.");
+ static_assert(size<2>(TileShape_MNK{}) == ScaleGranularityK , "Scale Granularity K must be equal to the tile shape K.");
+
+ using BlockTileScale_M = Int<size<0>(TileShape_MNK{}) / ScaleGranularityM>;
+ using BlockTileScale_N = Int<size<1>(TileShape_MNK{}) / ScaleGranularityN>;
+ using BlockTileScale_K = Int<size<2>(TileShape_MNK{}) / ScaleGranularityK>;
+
+ using ScaleTileShape = cute::Shape<BlockTileScale_M, BlockTileScale_N, BlockTileScale_K>;
+
+
+ // Setup Stages and DispatchPolicy
+ using MainloopPipelineStorage = typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage;
+
+ static constexpr int PipelineStages = detail::sm120_compute_stage_count_or_override_blockwise<
+ detail::sm120_smem_capacity_bytes, SmemAllocTypeA,
+ SmemAllocTypeB, ElementAccumulator,
+ TileShape_MNK, ScaleTileShape, MainloopPipelineStorage>(StageCountType{});
+ static constexpr uint32_t SchedulerPipelineStageCount = 2;
+ static constexpr bool IsGroupedGemmKernel = !cute::is_same_v<cute::remove_pointer_t<StrideA>, StrideA>;
+ using KernelSchedule = cute::conditional_t<IsGroupedGemmKernel,
+ // PtrArray
+ cute::conditional_t<IsCooperative,
+ KernelPtrArrayTmaWarpSpecializedCooperativeBlockwiseScalingSm120<SchedulerPipelineStageCount>,
+ KernelPtrArrayTmaWarpSpecializedPingpongBlockwiseScalingSm120<SchedulerPipelineStageCount>>,
+ // Non-PtrArray
+ cute::conditional_t<IsCooperative,
+ KernelTmaWarpSpecializedCooperativeBlockwiseScalingSm120<SchedulerPipelineStageCount>,
+ KernelTmaWarpSpecializedPingpongBlockwiseScalingSm120<SchedulerPipelineStageCount>>>;
+
+ using DispatchPolicy = cute::conditional_t<IsGroupedGemmKernel,
+ MainloopSm120ArrayTmaWarpSpecializedBlockwiseScaling<PipelineStages,
+ SchedulerPipelineStageCount,
+ ClusterShape_MNK,
+ KernelSchedule>,
+ MainloopSm120TmaWarpSpecializedBlockwiseScaling<PipelineStages,
+ SchedulerPipelineStageCount,
+ ClusterShape_MNK,
+ KernelSchedule>>;
+
+ using SmemCopyAtomA = Copy_Atom<decltype(detail::sm120_rr_smem_copy_selector_A<ElementA, ElementB, UseF8f6f4>()), SmemAllocTypeA>;
+ using SmemCopyAtomB = Copy_Atom<decltype(detail::sm120_rr_smem_copy_selector_B<ElementA, ElementB, UseF8f6f4>()), SmemAllocTypeB>;
+
+
+ using CollectiveOp = CollectiveMma<
+ DispatchPolicy,
+ TileShape_MNK,
+ ElementA,
+ cute::tuple<StrideA, StrideSFA>,
+ ElementB,
+ cute::tuple<StrideB, StrideSFB>,
+ TiledMma,
+ GmemTiledCopyA,
+ SmemLayoutAtomA,
+ SmemCopyAtomA,
+ cute::identity,
+ GmemTiledCopyB,
+ SmemLayoutAtomB,
+ SmemCopyAtomB,
+ cute::identity
+ >;
+ };
+
+ } // namespace cutlass::gemm::collective
+
+ /////////////////////////////////////////////////////////////////////////////////////////////////
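To make the auto-carveout path in the new builder concrete: a hedged, standalone sketch of the same stage computation with made-up numbers (128x128x128 tile, 8-bit A/B, 32-bit scales, one scale block per tile, illustrative capacity and carveout). The real builder derives every one of these values from its template parameters; none of the constants below come from the source.

```cpp
#include <cstdint>

constexpr int bits_to_bytes(long long bits) { return int(bits / 8); }

// Hypothetical configuration (illustrative only).
constexpr int TileM = 128, TileN = 128, TileK = 128;   // CTA tile
constexpr int ScaleM = 1,  ScaleN = 1,  ScaleK = 1;    // scale blocks per CTA tile
constexpr int a_bits = 8, b_bits = 8, scale_bits = 32; // e.g. FP8 inputs, float scales
constexpr int mainloop_pipeline_bytes = 16;            // stand-in for sizeof(MainloopPipelineStorage)

// Per-stage shared memory footprint: A tile + B tile + per-block scales + pipeline storage.
constexpr int stage_bytes =
    bits_to_bytes(a_bits * TileM * TileK) +            // 16384 B
    bits_to_bytes(b_bits * TileN * TileK) +            // 16384 B
    bits_to_bytes(scale_bits * ScaleM * ScaleK) +      //     4 B
    bits_to_bytes(scale_bits * ScaleN * ScaleK) +      //     4 B
    mainloop_pipeline_bytes;                           //    16 B

constexpr int CapacityBytes  = 100 * 1024;             // illustrative SMEM budget
constexpr int carveout_bytes = 4 * 1024;               // illustrative carveout

constexpr int PipelineStages = (CapacityBytes - carveout_bytes) / stage_bytes;
static_assert(stage_bytes == 32792, "per-stage bytes for this made-up config");
static_assert(PipelineStages == 2, "(102400 - 4096) / 32792 == 2");
```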
@@ -66,6 +66,8 @@ struct CollectiveBuilder<
StageCountType,
BuilderScheduleTag,
cute::enable_if_t<
+ not cute::is_tuple_v<ElementA> && not cute::is_tuple_v<ElementB> &&
+ not cute::is_tuple_v<GmemLayoutATag> && not cute::is_tuple_v<GmemLayoutBTag> &&
// Dense Gemm
(cute::is_base_of_v<KernelScheduleSm120DenseGemm, BuilderScheduleTag> ||
cute::is_base_of_v<KernelTmaWarpSpecializedPingpong, BuilderScheduleTag> ||
@@ -50,6 +50,7 @@
#include "cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl"
#include "cutlass/gemm/collective/builders/sm120_sparse_mma_builder.inl"
#include "cutlass/gemm/collective/builders/sm120_blockscaled_sparse_mma_builder.inl"
+ #include "cutlass/gemm/collective/builders/sm120_blockwise_mma_builder.inl"
#endif

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -67,6 +67,8 @@
#include "cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp"
#include "cutlass/gemm/collective/sm120_sparse_mma_tma.hpp"
#include "cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp"
+ #include "cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp"
+ #include "cutlass/gemm/collective/sm120_mma_array_tma_blockwise_scaling.hpp"
#endif // !defined(__CUDACC_RTC__)


@@ -28,10 +28,6 @@
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
-
-
-
-
#pragma once

#include "cutlass/cutlass.h"
@@ -51,7 +47,6 @@
#include "cute/arch/cluster_sm90.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
- #include "cute/tensor_predicate.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -169,7 +164,6 @@ struct CollectiveMma<
using InternalStrideB = cute::remove_pointer_t<StrideB>;

static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementA>();

static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementB>();

static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) ||
@@ -210,19 +204,15 @@ struct CollectiveMma<
AtomThrShapeMNK>;
using MainloopPipelineState = typename MainloopPipeline::PipelineState;

- static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)");
- static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0,
- "SmemLayoutAtom must evenly divide tile shape.");
- static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0,
- "SmemLayoutAtom must evenly divide tile shape.");
+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+ static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide the tile shape.");
+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide the tile shape.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM100 UMMA cannot have a non-void copy atom for smem sourced instructions.");

- static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)");
- static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0,
- "SmemLayoutAtom must evenly divide tile shape.");
- static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0,
- "SmemLayoutAtom must evenly divide tile shape.");
+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+ static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide the tile shape.");
+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide the tile shape.");
static_assert(cute::is_void_v<SmemCopyAtomB>,
"SM100 UMMA cannot have a non-void copy atom for smem sourced instructions.");

@@ -275,8 +265,8 @@ struct CollectiveMma<
using SmemAllocTypeA = cute::conditional_t<IsF8F6F4 && cute::sizeof_bits_v<ElementAMma> < 8, uint8_t, ElementAMma>;
using SmemAllocTypeB = cute::conditional_t<IsF8F6F4 && cute::sizeof_bits_v<ElementBMma> < 8, uint8_t, ElementBMma>;

- using BitTypeElementA = uint_bit_t<cute::sizeof_bits_v<ElementA>>;
- using BitTypeElementB = uint_bit_t<cute::sizeof_bits_v<ElementB>>;
+ using BitTypeElementA = cute::uint_bit_t<cute::sizeof_bits_v<ElementA>>;
+ using BitTypeElementB = cute::uint_bit_t<cute::sizeof_bits_v<ElementB>>;

using ArrayElementA = cute::conditional_t<IsRuntimeDataTypeA, BitTypeElementA, ElementA>;
using ArrayElementB = cute::conditional_t<IsRuntimeDataTypeB, BitTypeElementB, ElementB>;
@@ -308,15 +298,22 @@ struct CollectiveMma<
using TensorMapStorage = typename SharedStorage::TensorMapStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;

+ // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly
static constexpr uint32_t SFTransactionBytes =
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v<ElementSF>) +
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v<ElementSF>);
- // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly
static constexpr uint32_t ABTmaTransactionBytes =
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v<ElementA>) +
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v<ElementB>);
static constexpr uint32_t TmaTransactionBytes = ABTmaTransactionBytes + SFTransactionBytes;

+ template <class AccTensor, class SfaTensor, class SfbTensor>
+ struct TmemStorage {
+ AccTensor accumulators;
+ SfaTensor tCtSFA;
+ SfbTensor tCtSFB;
+ };
+
// Host side kernel arguments
struct Arguments {
ArrayElementA const** ptr_A{nullptr};
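A hedged numeric sketch of the transaction-byte accounting above, with made-up extents and element widths; the real values come from the SMEM layout cosizes, element types, and AtomThrShapeMNK of this collective.

```cpp
#include <cstdint>

constexpr uint32_t bits_to_bytes(uint64_t bits) { return uint32_t(bits / 8); }

// Illustrative per-stage SMEM cosizes; the real values are
// cosize(take<0,3>(SmemLayout{A,B,SFA,SFB}{})) from the collective.
constexpr uint64_t smem_A_elems   = 128ull * 128;  // BLK_M x BLK_K
constexpr uint64_t smem_B_elems   = 128ull * 128;  // BLK_N x BLK_K
constexpr uint64_t smem_SFA_elems = 128;           // scale factors for A
constexpr uint64_t smem_SFB_elems = 128;           // scale factors for B
constexpr uint32_t atom_thr_size  = 2;             // 2SM MMA: one CTA issues TMA for both

constexpr uint32_t ElementBits   = 8;              // e.g. FP8 A/B
constexpr uint32_t ElementSFBits = 8;              // 8-bit scale factors, purely illustrative

constexpr uint32_t SFTransactionBytes =
    bits_to_bytes(atom_thr_size * smem_SFA_elems * ElementSFBits) +
    bits_to_bytes(atom_thr_size * smem_SFB_elems * ElementSFBits);   // 512 B

constexpr uint32_t ABTmaTransactionBytes =
    bits_to_bytes(atom_thr_size * smem_A_elems * ElementBits) +
    bits_to_bytes(atom_thr_size * smem_B_elems * ElementBits);       // 65536 B

constexpr uint32_t TmaTransactionBytes = ABTmaTransactionBytes + SFTransactionBytes;
static_assert(TmaTransactionBytes == 66048, "sum of A/B and SFA/SFB bytes per stage");
```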
@@ -401,7 +398,11 @@ struct CollectiveMma<
CUTLASS_DEVICE
CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster)
: cluster_shape_(cluster_shape)
- , block_rank_in_cluster_(block_rank_in_cluster) {
+ , block_rank_in_cluster_(block_rank_in_cluster)
+ , layout_SFA_(params.layout_SFA)
+ , layout_SFB_(params.layout_SFB)
+ , runtime_data_type_a_(params.runtime_data_type_a)
+ , runtime_data_type_b_(params.runtime_data_type_b) {
if constexpr (IsDynamicCluster) {
const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x &&
cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y);
@@ -613,18 +614,48 @@ struct CollectiveMma<
}

/// Construct A Single Stage's Accumulator Shape
- CUTLASS_DEVICE auto
+ CUTLASS_DEVICE static
+ auto
partition_accumulator_shape() {
auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N)

return acc_shape;
}

+ template <class TmemStorage>
+ CUTLASS_DEVICE static
+ auto
+ slice_accumulator(TmemStorage tmem_storage, int stage) {
+ return tmem_storage.accumulators(_,_,_,stage);
+ }
+
- template <class FrgEngine, class FrgLayout>
- CUTLASS_DEVICE auto
- slice_accumulator(cute::Tensor<FrgEngine, FrgLayout> const& accumulators, int stage) {
- return accumulators(_,_,_,stage);
+ template <class EpilogueTile, bool IsOverlappingAccum = false>
+ CUTLASS_DEVICE static
+ auto
+ init_tmem_tensors(EpilogueTile epi_tile) {
+ TiledMma tiled_mma;
+ auto acc_shape = partition_accumulator_shape();
+ // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue.
+ Tensor accumulators = cutlass::detail::make_sm100_accumulator<AccumulatorPipelineStageCount, IsOverlappingAccum>(
+ tiled_mma, acc_shape, EpilogueTile{});
+ Tensor tCtSFA = make_tensor<typename TiledMma::FrgTypeSFA>(shape(SmemLayoutAtomSFA{}));
+ Tensor tCtSFB = make_tensor<typename TiledMma::FrgTypeSFB>(shape(SmemLayoutAtomSFB{}));
+
+ TmemStorage<decltype(accumulators), decltype(tCtSFA), decltype(tCtSFB)> tmem_storage;
+ tmem_storage.accumulators = accumulators;
+ tmem_storage.tCtSFA = tCtSFA;
+ tmem_storage.tCtSFB = tCtSFB;
+
+ return tmem_storage;
+ }
+
+ template <class TmemStorage>
+ CUTLASS_DEVICE static
+ void
+ set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) {
+ tmem_storage.accumulators.data() = tmem_base_addr;
+ tmem_storage.tCtSFA.data() = tmem_storage.accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.accumulators);
+ tmem_storage.tCtSFB.data() = tmem_storage.tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.tCtSFA);
}

/// Set up the data needed by this collective for load.
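The new static helpers above (init_tmem_tensors, set_tmem_offsets, slice_accumulator over a TmemStorage) suggest a call order of build, then place once the TMEM base address is known, then slice per accumulator stage. The following is an assumed usage sketch with toy stand-in types, inferred only from the signatures in this hunk, not the actual kernel code.

```cpp
#include <cstdint>
#include <cstdio>

// Toy stand-ins for TMEM-backed tensors; in the collective these are
// cute::Tensor objects whose data() is a tensor-memory address.
struct ToyTensor { uint32_t tmem_addr = 0; int stages = 2; };

template <class Acc, class Sfa, class Sfb>
struct TmemStorage {        // mirrors the aggregate added in this hunk
  Acc accumulators;
  Sfa tCtSFA;
  Sfb tCtSFB;
};

using Storage = TmemStorage<ToyTensor, ToyTensor, ToyTensor>;

Storage init_tmem_tensors() { return Storage{}; }           // 1. build CTA-local TMEM tensors
void set_tmem_offsets(Storage& s, uint32_t tmem_base) {     // 2. place them once TMEM is allocated
  s.accumulators.tmem_addr = tmem_base;
  s.tCtSFA.tmem_addr = tmem_base + 256;                     // offsets are illustrative only
  s.tCtSFB.tmem_addr = s.tCtSFA.tmem_addr + 32;
}
ToyTensor slice_accumulator(Storage const& s, int stage) {  // 3. pick the buffer of one ACC stage
  ToyTensor t = s.accumulators;
  t.tmem_addr += uint32_t(stage) * 128;                     // illustrative per-stage stride
  return t;
}

int main() {
  Storage s = init_tmem_tensors();
  set_tmem_offsets(s, /*tmem_base_addr=*/0);
  for (int stage = 0; stage < s.accumulators.stages; ++stage) {
    ToyTensor acc = slice_accumulator(s, stage);
    std::printf("stage %d uses TMEM offset %u\n", stage, acc.tmem_addr);
  }
}
```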
@@ -693,9 +724,9 @@ struct CollectiveMma<
}
else if constexpr (IsCtaN64) {
Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB));
auto new_shape = make_shape(make_shape(shape<0,0>(mSFB_tmp),
make_shape(_2{} , shape<0,1>(mSFB_tmp))), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp));
auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp),
make_stride(_0{}, stride<0,1>(mSFB_tmp))), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp));
return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride));
}
@@ -707,7 +738,6 @@ struct CollectiveMma<
Tensor gSFA_mkl = local_tile(mSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (TILE_M,TILE_K,m,k,l)
Tensor gSFB_nkl = local_tile(mSFB_nkl, TileShape_SF{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l)
-

// Partition for this CTA
ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{}));

@@ -770,17 +800,15 @@ struct CollectiveMma<
}

/// Set up the data needed by this collective for mma compute.
- template <class FrgEngine, class FrgLayout>
+ template <class TmemStorage>
CUTLASS_DEVICE auto
mma_init(
- Params const& params,
- [[maybe_unused]] cute::Tensor<FrgEngine, FrgLayout> const& accumulators,
- TensorStorage& shared_tensors,
- uint32_t const tmem_offset) const {
+ TmemStorage tmem_storage,
+ TensorStorage& shared_tensors) const {

// Allocate "fragments/descriptors" for A and B matrices
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)

// Allocate "fragments/descriptors" for A and B matrices
Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
@@ -792,13 +820,8 @@ struct CollectiveMma<
//
// Scale Factor
//
- Tensor tCtSFA = make_tensor<typename TiledMma::FrgTypeSFA>(shape(SmemLayoutAtomSFA{}));
- // Set tCtSFA and tCtSFB start addresses. Only update the TMEM column address by masking the address with 0x000001FF.
- // TMEM allocations for SFA and SFB will always start at DP 0.
- tCtSFA.data() = tmem_offset;
- Tensor tCtSFB = make_tensor<typename TiledMma::FrgTypeSFB>(shape(SmemLayoutAtomSFB{}));
- tCtSFB.data() = tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tCtSFA);
+ Tensor tCtSFA = tmem_storage.tCtSFA;
+ Tensor tCtSFB = tmem_storage.tCtSFB;

// Setup smem descriptors for UTCCP
Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{});
Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{});
@@ -831,8 +854,10 @@ struct CollectiveMma<
TiledMma tiled_mma;

if constexpr (IsRuntimeDataType) {
- tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111;
- tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111;
+ // Update instruction descriptor according to runtime argument.
+ // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe.
+ tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111;
+ tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111;
}

return cute::make_tuple(
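The new comments above explain the 0b111 mask; here is a minimal standalone sketch of the same pattern with a hypothetical format enum (placeholder enumerator values, not the real CUTLASS encodings).

```cpp
#include <cstdint>

// Hypothetical runtime MMA format enum; placeholder values only.
enum class MmaFormat : uint8_t { E4M3 = 0, E5M2 = 1, E2M1 = 5 };

struct InstrDesc {
  uint8_t a_format_ : 3;  // 3-bit data-format fields in the instruction descriptor
  uint8_t b_format_ : 3;
};

inline void set_formats(InstrDesc& idesc, MmaFormat a, MmaFormat b) {
  // Masking with 0b111 keeps only the low three bits, so the compiler can see
  // that assigning into the 3-bit bitfields cannot silently truncate.
  idesc.a_format_ = uint8_t(a) & 0b111;
  idesc.b_format_ = uint8_t(b) & 0b111;
}

int main() {
  InstrDesc idesc{};
  set_formats(idesc, MmaFormat::E4M3, MmaFormat::E5M2);
  return (idesc.a_format_ == 0 && idesc.b_format_ == 1) ? 0 : 1;
}
```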
@@ -997,45 +1022,52 @@ struct CollectiveMma<
//
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;

- if (k_tile_count > 0) { // first iteraion
- // WAIT on mainloop_pipe_consumer_state until its data are available
- // (phase bit flips from mainloop_pipe_consumer_state.phase() value)
- mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token);
+ if constexpr (IsOverlappingAccum) {
+ // first iteration manual unroll for tmem overlap kernel
+ if (k_tile_count > 0) {
+ // WAIT on mainloop_pipe_consumer_state until its data are available
+ // (phase bit flips from mainloop_pipe_consumer_state.phase() value)
+ mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token);

// Compute on k_tile
int read_stage = mainloop_pipe_consumer_state.index();
// Save current mainlop pipeline read state
auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state;

// Advance mainloop_pipe
++mainloop_pipe_consumer_state;
--k_tile_count;
skip_wait = k_tile_count <= 0;
// Peek at next iteration
barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);

if (cute::elect_one_sync()) {
copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t);
copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t);
}

- if constexpr (IsOverlappingAccum) {
+ // Wait for tmem accumulator buffer to become empty with a flipped phase
accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state);
- }

// Unroll the K mode manually so we can set scale C to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M) x (V,N) => (V,M,N)
cute::gemm(tiled_mma.with(tiled_mma.accumulate_,
tCtSFA(_,_,k_block),
tCtSFB_mma(_,_,k_block)),
tCrA(_,_,k_block,read_stage),
tCrB(_,_,k_block,read_stage),
accumulators);
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
+ }
+
+ mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state);
}
- mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state);
+ }
+ else {
+ // Wait for tmem accumulator buffer to become empty with a flipped phase
+ accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state);
}

CUTLASS_PRAGMA_NO_UNROLL
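Structurally, the rewrite above peels the first k-tile only when IsOverlappingAccum is set, so the wait for a free accumulator buffer happens after the first stage's data are already being consumed; otherwise the buffer is acquired up front before the uniform loop. A hedged, host-side control-flow sketch with stub pipelines, not the actual kernel code:

```cpp
#include <cstdio>

// Minimal stand-ins for the pipeline objects; the real kernel uses CUTLASS
// pipeline classes and cute::gemm as shown in the hunk above.
struct MainloopPipeline {
  void consumer_wait()    { std::puts("wait for A/B/SF stage"); }
  void consumer_release() { std::puts("release stage"); }
};
struct AccumulatorPipeline {
  void producer_acquire() { std::puts("acquire free accumulator buffer"); }
};

template <bool IsOverlappingAccum>
void mma_consumer_loop(MainloopPipeline& mainloop, AccumulatorPipeline& acc, int k_tile_count) {
  if constexpr (IsOverlappingAccum) {
    // Overlap path: peel the first k-tile so the accumulator acquire happens
    // only after the first stage's data are already in flight.
    if (k_tile_count > 0) {
      mainloop.consumer_wait();
      acc.producer_acquire();
      std::puts("gemm on first k-tile (accumulate_ = Zero -> One)");
      mainloop.consumer_release();
      --k_tile_count;
    }
  } else {
    // Non-overlap path: acquire the accumulator buffer up front.
    acc.producer_acquire();
  }
  while (k_tile_count-- > 0) {
    mainloop.consumer_wait();
    std::puts("gemm on k-tile");
    mainloop.consumer_release();
  }
}

int main() {
  MainloopPipeline m; AccumulatorPipeline a;
  mma_consumer_loop<true>(m, a, 3);   // overlapping-accumulator schedule
  mma_consumer_loop<false>(m, a, 3);  // plain schedule
}
```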
@@ -1073,6 +1105,7 @@ struct CollectiveMma<
accumulators);
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
}
+
mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state);
}

@@ -1273,6 +1306,11 @@ protected:
typename Params::TMA_SFA const* observed_tma_load_sfa_{nullptr};
typename Params::TMA_SFB const* observed_tma_load_sfb_{nullptr};

+ LayoutSFA layout_SFA_;
+ LayoutSFB layout_SFB_;
+ RuntimeDataTypeA runtime_data_type_a_{};
+ RuntimeDataTypeB runtime_data_type_b_{};
+
ClusterShape cluster_shape_;
uint32_t block_rank_in_cluster_;
};
@@ -47,7 +47,6 @@
#include "cute/arch/cluster_sm90.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
- #include "cute/tensor_predicate.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -123,7 +122,7 @@ struct CollectiveMma<
"Static cluster shape used: TileShape should be evenly divided by TiledMma");

using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{}));
static_assert(shape<1>(CtaShape_MNK{}) == 192 or shape<1>(CtaShape_MNK{}) == 64 or
shape<1>(CtaShape_MNK{}) == 128 or shape<1>(CtaShape_MNK{}) == 256,
"Cta N should be one of 64/128/192/256");

@@ -726,9 +725,9 @@ struct CollectiveMma<
}
else if constexpr (IsCtaN64) {
Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_));
auto new_shape = make_shape(make_shape(shape<0,0>(mSFB_tmp),
make_shape(_2{} , shape<0,1>(mSFB_tmp))), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp));
auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp),
make_stride(_0{}, stride<0,1>(mSFB_tmp))), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp));
return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride));
}
Some files were not shown because too many files have changed in this diff.