Compare commits
263 Commits
thakkarV-p
...
Deepseek
| Author | SHA1 | Date | |
|---|---|---|---|
| e9a75581fe | |||
| ac210faef8 | |||
| 15f5468872 | |||
| af5519d938 | |||
| 415d587ebf | |||
| eefa171318 | |||
| afa1772203 | |||
| 9b3772dfa6 | |||
| b84e9802d8 | |||
| e9627ce55b | |||
| ad6e1ec19c | |||
| 0642d46dd4 | |||
| 833f6990e0 | |||
| affd1b693d | |||
| 6f55278121 | |||
| 3c28697b9f | |||
| bdd641790a | |||
| cc19d4d22b | |||
| 47daa33c61 | |||
| 389e493055 | |||
| 9eb01fa0b0 | |||
| b78588d163 | |||
| 902dff3663 | |||
| ef5620dd1d | |||
| 375e284e6a | |||
| 52b35e90ce | |||
| 24f991e879 | |||
| 51b25e7b58 | |||
| 7de6a59784 | |||
| c506e16788 | |||
| 7494a180a4 | |||
| cffd5d32b7 | |||
| bf9da7b76c | |||
| 3d261a5974 | |||
| e1cd8c7866 | |||
| 33c584364e | |||
| 2b6cfd34d1 | |||
| 4c42f73fda | |||
| 80243e0b8c | |||
| b0e09d7cd3 | |||
| 8aa95dbb88 | |||
| d656afbd2a | |||
| 32e3c38aef | |||
| 9004ed2d1b | |||
| 19f51596e8 | |||
| e8a8b69365 | |||
| 08a49953a0 | |||
| a424ca6cf9 | |||
| be692b48b0 | |||
| 12626bcfe4 | |||
| f02913c34e | |||
| 03e3bffaec | |||
| e5f3caf145 | |||
| 83ae20c740 | |||
| b0c09ed077 | |||
| ea69cc2849 | |||
| f3a3bfcbf2 | |||
| d65266a868 | |||
| 5b50a8faaf | |||
| 08101d9d0c | |||
| 755194a7bd | |||
| 53668799b2 | |||
| cc3c29a81a | |||
| 0837a2a00a | |||
| 477a677317 | |||
| b27c49e84a | |||
| e2b0789927 | |||
| 44dae8b90e | |||
| 2991ce18d3 | |||
| 1ebda1ccef | |||
| 9f68995de5 | |||
| 3a8c01a18b | |||
| dbdae514e0 | |||
| 21d0534167 | |||
| 323c8170bf | |||
| 82f5075946 | |||
| 06e337758d | |||
| 7369adcaca | |||
| 6c3044136b | |||
| e1976daacc | |||
| f7b19de32c | |||
| 4dbf5dbed2 | |||
| f93a69134e | |||
| 3f084f7f3c | |||
| b0296bf682 | |||
| 865be73a97 | |||
| 8d8cfdf375 | |||
| fb170439e8 | |||
| 4e5a8f6853 | |||
| 7192f4ab23 | |||
| 2049c6c5a2 | |||
| e22ba590cd | |||
| 19b4c5e065 | |||
| 06b21349bc | |||
| eee0cab26c | |||
| 36cbfcf483 | |||
| 1f2b590da6 | |||
| 8b2a0408bd | |||
| fbd116c0e5 | |||
| 5b283c872c | |||
| be60a0b272 | |||
| 56b46e2d13 | |||
| 52fb43f30f | |||
| 843adf0408 | |||
| e48c7618e4 | |||
| c5239d8312 | |||
| d6580c3dc0 | |||
| 81b06ee0e0 | |||
| dbfced05e7 | |||
| 2448bb56e6 | |||
| 637b159063 | |||
| 033d9efd2d | |||
| acc3ee18a1 | |||
| 5c447dd84f | |||
| 7d49e6c7e2 | |||
| a40e08e9d5 | |||
| 8e7d9f483d | |||
| 19f3cc33f1 | |||
| f9ece1b42c | |||
| 28cbacbf64 | |||
| 8f7d2789b8 | |||
| c4e3e122e2 | |||
| 629f4653c3 | |||
| ffa34e7075 | |||
| a8f2c80db0 | |||
| bbe579a9e3 | |||
| 47a3ebbea9 | |||
| 57e01e1a6b | |||
| 6e3df975a2 | |||
| 8825fbf1ef | |||
| 092f14db05 | |||
| 9385141f19 | |||
| b4b5b11070 | |||
| 139b93db61 | |||
| ca37d632c9 | |||
| 362abbf274 | |||
| 751eb9a885 | |||
| 2f589ffa76 | |||
| acba5beee5 | |||
| 74d1f3e63a | |||
| 8ac2edc810 | |||
| d4be5ab5d7 | |||
| c9591a694d | |||
| 5c756eb774 | |||
| 8236f30675 | |||
| b7508e3379 | |||
| f60786b536 | |||
| 30ec1a4649 | |||
| e1483d5fa0 | |||
| f4a0216601 | |||
| f188f9b709 | |||
| 9c9b51d35c | |||
| a75b4ac483 | |||
| e9e30c2304 | |||
| 4a1709e17e | |||
| bef1fbcbe6 | |||
| 2375a07d01 | |||
| 60c8251b72 | |||
| 10b850f9c7 | |||
| 99c4eebe3b | |||
| a759e85f5f | |||
| 56fc3df03b | |||
| eb01d5449d | |||
| 8098336d51 | |||
| b5d8a5d9cc | |||
| 6e60b9b17c | |||
| 1ab6cc7b68 | |||
| 5ae8133cfa | |||
| 39c6a83f23 | |||
| 1d7f2a207e | |||
| 557be3ab0e | |||
| c008b4aea8 | |||
| 922fb5108b | |||
| 7a7796afae | |||
| fb10fa5308 | |||
| 5e1a0a5adb | |||
| 757275f279 | |||
| fa8dfe631f | |||
| 112590114d | |||
| ff02da2667 | |||
| 4082fed85a | |||
| 5f13dcad78 | |||
| 61a38f83dc | |||
| ff61a49dd1 | |||
| 26986bbc60 | |||
| 7d8317a63e | |||
| 5cd735c48e | |||
| 67ae8e0603 | |||
| 14f69bddc8 | |||
| 90d3b0fb18 | |||
| e0aaa3c3b3 | |||
| 8783c41851 | |||
| 6407bcdf0a | |||
| a77b2c9cb8 | |||
| 34bbadd3ff | |||
| 88c0d7c726 | |||
| e01b9b5029 | |||
| 34fd98056b | |||
| 3a8f57a3c8 | |||
| 6673df0e48 | |||
| 7618e9bfd8 | |||
| a88c41cf8d | |||
| 27de343535 | |||
| 2a9fa23e06 | |||
| 2e56cfabee | |||
| 3930f709ce | |||
| 7e5ee8b7bf | |||
| 2d9a557427 | |||
| 4575443d44 | |||
| a0d787b746 | |||
| d20f3a9542 | |||
| 8e85580859 | |||
| 146d314057 | |||
| f679663224 | |||
| e066ced33b | |||
| 9b923dd4c4 | |||
| f6d42f2dd0 | |||
| 473a67073e | |||
| 87349d3496 | |||
| fde824af21 | |||
| 7dbf423763 | |||
| 6f47420213 | |||
| 4638250469 | |||
| 7859fe322a | |||
| d3e72719b4 | |||
| b4ab501767 | |||
| f079619f5e | |||
| 13f413493a | |||
| 6fbc0d3380 | |||
| b97404837e | |||
| e2953d47c5 | |||
| 19c4a4815e | |||
| fcfbd23e26 | |||
| b250faccd3 | |||
| 24c8b7d8a2 | |||
| 7c04f95415 | |||
| 6f8596ce3f | |||
| fe2f491dd7 | |||
| df02482f1d | |||
| 180c5629bf | |||
| e36912f961 | |||
| 9a83bd3381 | |||
| 54bebe417d | |||
| 43cfbe0086 | |||
| 4a68cf748e | |||
| d572cc1aab | |||
| 9b8166e3f0 | |||
| e2d439ee7e | |||
| 0435979f59 | |||
| 2ba1ef10be | |||
| 0964bdb64c | |||
| ecbd24566c | |||
| 660a05f581 | |||
| bc36122c3f | |||
| 15d9d31f1f | |||
| 1eef5c3cf1 | |||
| 87070b6d51 | |||
| 77549ae6c8 | |||
| 42290f5d1c | |||
| 209faf7b94 | |||
| 6116706c96 | |||
| 2670b973dd | |||
| af332d4aa9 |
5
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
5
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
blank_issues_enabled: true
|
||||
contact_links:
|
||||
- name: CUTLASS Discord
|
||||
url: https://discord.gg/nvidiadeveloper
|
||||
about: Come chat about using and contributing to CUTLASS!
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@ -1,2 +1,4 @@
|
||||
# PyCache files
|
||||
__pycache__/
|
||||
cutlass_library.egg-info/
|
||||
/build*
|
||||
|
||||
405
CHANGELOG.md
405
CHANGELOG.md
@ -1,33 +1,244 @@
|
||||
# NVIDIA CUTLASS Changelog
|
||||
|
||||
## [3.8.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.8.0) (2025-01-25)
|
||||
|
||||
* Support for new CuTe building blocks specifically for Blackwell SM100 architecture:
|
||||
- [5th generation Blackwell Tensor Core instructions (TCGen05)](./include/cute/atom/mma_traits_sm100.hpp) via CuTe MMA atoms.
|
||||
- Extensions to [Tensor Memory Accelerator](./include/cute/atom/copy_traits_sm100_tma.hpp) via CuTe Copy atoms.
|
||||
- Exposure of Blackwell's new tensor memory (note: distinct from TMA) as [`tmem`](./include/cute/pointer.hpp) across CuTe as a first class data locale.
|
||||
- Exposure of [`tmem->rmem`, `rmem->tmem` and `smem->tmem data movement instructions`](./include/cute/atom/copy_traits_sm100.hpp) as copy atoms in CuTe.
|
||||
- [`make_tmem_copy()`](./include/cute/atom/copy_traits_sm100.hpp) utility method to ease creation of tiled copies for tmem copy atoms.
|
||||
- Support for [new variants of LDSM on Blackwell](./include/cute/atom/copy_traits_sm100.hpp) via CuTe Copy atoms.
|
||||
* Support for new CUTLASS building blocks specifically for Blackwell SM100 architecture:
|
||||
- Various narrow precision [FP4, FP6, and FP8](./include/cutlass/exmy_base.h) formats as well as their [block-scaled variants NVFP4, MXFP4, MXFP6, and MXFP8](./include/cutlass/float_subbyte.h)
|
||||
- [Pipelines that implement Blackwell specific synchronization](./include/cutlass/pipeline/sm100_pipeline.hpp).
|
||||
- [Cluster launch control API supporting preferred and fallback cluster shapes](./include/cutlass/cluster_launch.hpp).
|
||||
- Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types.
|
||||
- Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./media/docs/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](./include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](./include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp).
|
||||
- Extensions to testbeds and reference check code for unit tests and CUTLASS profiler.
|
||||
* Full support for Blackwell SM100 kernels in CUTLASS 3.x API:
|
||||
- [Blackwell specific kernel layers](./include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that
|
||||
+ Implement a new warp-specialization recipe tuned specifically for Blackwell SM100 architecture.
|
||||
+ Leverage all the new features such as CLC based tile scheduling, preferred cluster, and TMEM based double buffering of accumulators.
|
||||
+ Support stream-K load balancing for all kernel types everywhere via composable scheduler support.
|
||||
- Blackwell collective mainloops that target the TCGen05 MMA instructions (both SS and TS) for
|
||||
* [Non-block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp)
|
||||
* [Non-block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp)
|
||||
* [Block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp)
|
||||
* [Block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp)
|
||||
- Blackwell [collective mainloop for convolution kernels](./include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp) supporting non-block scaled data types for fprop, dgrad, and wgrad.
|
||||
- New [GEMM](./include/cutlass/gemm/dispatch_policy.hpp), [convolution](./include/cutlass/conv/dispatch_policy.hpp), and [epilogue](./include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
|
||||
- [Blackwell epilogue that supports loading accumulators from `tmem`](./include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp) and [full set of EVT fusions]().
|
||||
* CUTLASS library and profiler integration for block scaled data types for kernel emission, profiling, and verification.
|
||||
- Support for preferred and fallback cluster shapes via profiler command line arguments parsing to set dynamic cluster shapes.
|
||||
- Support for dynamic datatypes by parsing profiler via profiler command line arguments parsing to set dynamic datatype setting in TCGen05 MMA instruction descriptors.
|
||||
- Support for mixed input GEMM kernels on Hopper in the profiler.
|
||||
* New CUTLASS profiler flag `use-cuda-graphs` to reduce overheads when benchmarking launch-bound kernels.
|
||||
* A new 3.x version of grouped GEMM to the CUTLASS library and generates kernels for Hopper and Blackwell. Now grouped GEMM support is enabled in the CUTLASS profiler (`./cutlass_profiler --operation=GroupedGemm --help` for details).
|
||||
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM100 architecture:
|
||||
- [Basic FP16 and FP8 GEMMs with minimal changes from Hopper examples](./examples/70_blackwell_gemm/), demonstrating ease of migration for off the shelf kernels using the 3.x collective builder API.
|
||||
- GEMM with [opt-in collective builder schedules showcasing available recipes](./examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu) for Blackwell.
|
||||
- Block scaled data type GEMMs targeting Blackwell's native block scaled Tensor Cores:
|
||||
+ [NVFP4 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu)
|
||||
+ [NVFP4 inputs with NVFP4 output](./examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu)
|
||||
+ [Mixed MXFP8 and MXFP6 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu)
|
||||
- GEMM example demonstrating [Blackwell's new preferred cluster support via dynamic cluster shapes](./examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu) for increased occupancy.
|
||||
- [GEMM with CLC based StreamK scheduler for load balancing](./examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu).
|
||||
- Grouped GEMM for [vanilla FP8 data inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu) and [NVFP4 block scaled inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu).
|
||||
- Convolution kernels for [fprop](./examples/76_blackwell_conv/76_blackwell_conv_fprop.cu), [dgrad](./examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu), and [wgrad](./examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu).
|
||||
- [Fused multi-head attention fprop kernel](./examples/77_blackwell_fmha/77_blackwell_fmha.cu) supporting fp16/bf16/fp8 data types across head dims of 32,64, and 128.
|
||||
- A new BF16x9 GEMM [kernel](./examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu) that emulates FP32 GEMM (SGEMM) using BF16 operations.
|
||||
* Set of examples that demonstrate the usage of the 3.x API for targeting Hopper architecture:
|
||||
- A set of new [Hopper grouped GEMM kernels](./examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes.
|
||||
- A new [Hopper FP8 GEMM with groupwise scaling](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu).
|
||||
* Documentation updates:
|
||||
- [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/quickstart.md#instantiating-a-blackwell-gemm-kernel).
|
||||
- Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/blackwell_functionality.md)
|
||||
- A new [functionality documentation](./media/docs/functionality.md) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures.
|
||||
- Updates to [compatibility](./README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](./README.md#Target-Architecture).
|
||||
- Updates to [profiler documentation](./media/docs/profiler.md) for testing mixed input GEMM kernels on Hopper.
|
||||
|
||||
## [3.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.7.0) (2025-01-11)
|
||||
- [Hopper blockwise scaling FP8 GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) uses 2D scaling tensor, assigning one value per threadblock. This allows a finer-grained scaling to be applied for each output tile per gemm-k iteration. The operands and scaling tensors are loaded from global memory to shared memory using TMA and cp_async, respectively. The scaling is applied inside the mainloop. Details with figures are [here](https://github.com/NVIDIA/cutlass/pull/1932#issue-2645398439).
|
||||
- [Distributed GEMM](./examples/65_distributed_gemm/65_distributed_gemm.cu) is a new (experimental) API which can turn existing CUTLASS GEMM kernels into pipelined Tensor Parallel GEMMs that run efficiently on NVLink-based network of GPUs. Its pipelining schedules can hide most of the communication behind computation, and relies on point-to-point communication, which can simply use CUDA runtime's peer device access feature. It also utilizes remote TMA loads and memcopies with CUDA graphs to handle communication primarily through the Copy Engine, leaving all SMs free for Hopper's persistent kernels. For more details you can refer to the [DistGEMM blog post](https://blog.shi-labs.com/distributed-gemm-88be6a481e2b).
|
||||
- Improved persistent grid launch for Hopper kernels with large cluster sizes (>= size of 4) using the new `make_kernel_hardware_info` API as shown in [example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
|
||||
- Enabled high precision accumulation for Hopper FP8 Sparse GEMM.
|
||||
- Potential API breaking changes:
|
||||
+ Fix `cute::UniversalCopy` for type safety.
|
||||
+ No longer implicitly select `cute::SM80_CP_ASYNC_*` based on input tensors. This avoids implicit downstream synchronization requirements. To use `SM80_CP_ASYNC`, users must explicitly select the appropriate CopyAtom.
|
||||
+ Fix `cute::SM80_CP_ASYNC_CACHEALWAYS`, `cute::SM80_CP_ASYNC_CACHEGLOBAL`, `cute::SM80_CP_ASYNC_CACHEALWAYS_ZFILL`, `cute::SM80_CP_ASYNC_CACHEGLOBAL_ZFILL` to avoid implicitly selecting `ZFILL` behavior on predication.
|
||||
+ Remove `cute::copy_vec<T>` in favor of `cute::copy_aligned` and `cute::copy(AutoVectorizingCopyWithAssumedAlignment<NumBits>,...)`.
|
||||
+ A refactor of default epilogue struct `DefaultEpilogue` [API](./include/cutlass/epilogue/collective/default_epilogue.hpp) to avoid reading non-void `ElementC` value for `ElementC = void` kernel.
|
||||
- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](./media/docs/profiler.md#cutlass-profiler).
|
||||
- Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
|
||||
- Optimal code generation with CUDA toolkit versions 12.6.
|
||||
|
||||
## [3.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.6.0) (2024-10-03)
|
||||
|
||||
- [Hopper structured sparse GEMM](./examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu).
|
||||
+ [FP16](./test/unit/gemm/device/sm90_sparse_gemm_f16_f16_f32_tensor_op_f32.cu)
|
||||
+ [FP8](./test/unit/gemm/device/sm90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu)
|
||||
+ [INT8](./test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu)
|
||||
+ [TF32](./test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu)
|
||||
- A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](./include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered as a beta API.
|
||||
- [An improved mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode.
|
||||
- [EVT nodes for Top-K selection and softmax](./include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](./examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu).
|
||||
- [Programmatic Dependent Launch](./include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](./media/docs/dependent_kernel_launch.md).
|
||||
- [A new debugging tool, synclog](./include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details.
|
||||
- A new TMA-enabled [epilogue](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings significant performance improvement, as well as its EVT support.
|
||||
- A SIMT-enabled pointer-array [epilogue](./include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp).
|
||||
- A new [Ping-Pong kernel schedule for Grouped GEMM](./include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations.
|
||||
- [A new instantiation strategy for CUTLASS profiler kernels](./python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/profiler.md#instantiating-more-kernels-with-hopper).
|
||||
- A new hardware support for comparisons and computations of [`cutlass::bfloat16_t`](./include/cutlass/bfloat16.h)
|
||||
- Fixed use of isnan on Windows for [`half_t`](./test/unit/core/functional.cu).
|
||||
- Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
|
||||
- Optimal code generation with CUDA toolkit versions 12.6.
|
||||
|
||||
## [3.5.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.1) (2024-07-25)
|
||||
|
||||
- [Minimal SM90 WGMMA + TMA GEMM example in 100 lines of code](./examples/cute/tutorial/wgmma_sm90.cu)
|
||||
- [Exposure of L2 `cache_hint`s in TMA copy atoms](./include/cute/arch/copy_sm90_tma.hpp#L48)
|
||||
- Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/profiler.md#GEMM), and
|
||||
[example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
|
||||
- [TMA store based and EVT supported epilogues](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for [Hopper pointer array batched kernels](./test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu).
|
||||
- A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](./include/cutlass/gemm/device/gemm_sparse_universal.h) to enable serial and parallel split-k for sparse tensor cores and new tiny tile sizes to better support LLM inferrence:
|
||||
+ [FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu#L269-L393) and [NT](./test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu#L269-L411).
|
||||
+ [int8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452).
|
||||
+ [int4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452).
|
||||
+ [FP32 TN](./test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu#L427-L642) and [NT](./test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu#L427-L456).
|
||||
- [CUDA host adapter](./include/cutlass/cuda_host_adapter.hpp) extensions to support TMA descriptor construction driver APIs.
|
||||
- Inclusion of more [Hopper fprop, dgrad, and wgrad convolution kernels in CUTLASS library and profiler](./python/cutlass_library/generator.py).
|
||||
- Support for residual add (beta != 0) in convolution kernels.
|
||||
- A new convolution [epilogue](./examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output.
|
||||
- A refactor of [include files throughout CUTLASS core directories](./include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](./test/self_contained_includes/CMakeLists.txt).
|
||||
- [A guide for setting up VSCode to work well with CUTLASS](./media/docs/ide_setup.md) and [expanded code style guide](./media/docs/programming_guidelines.md).
|
||||
- Better support for MSVC as a host compiler.
|
||||
- Many performance optimizations, improvements, and bug fixes including fixes for FlashAttention-2.
|
||||
- Optimal code generation with CUDA toolkit versions 12.4 and 12.5u1.
|
||||
|
||||
## [3.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.0) (2024-04-09)
|
||||
|
||||
- Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](./include/cute/atom/copy_traits_sm90_im2col.hpp)
|
||||
+ Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/gemm_api_3x.md).
|
||||
+ Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](./include/cutlass/conv/convnd_problem_shape.hpp).
|
||||
+ Support for [Fprop](./test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](./test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](./test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms
|
||||
+ [CUTLASS profiler support](./python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API.
|
||||
+ NOTE: this is a beta release. Further updates to CUTLASS will include major performance improvements, feature enablement, and possible breaking changes to the API until 3.7 release. Your feedback is welcome on the design!
|
||||
- Support for [Ada (SM89) FP8 tensor cores via the 2.x API](./examples/58_ada_fp8_gemm/ada_fp8_gemm.cu). Requires CUDA 12.4 or newer.
|
||||
- [Ampere gather/scatter convolution example](./examples/59_ampere_gather_scatter_conv/README.md) in CuTe and CUTLASS 3.x
|
||||
+ Showcasing how custom kernels can be written and optimized using CUTLASS 3.x and CuTe and the general strategy for implementing convolutions as specializations of GETTs.
|
||||
+ Implementation of a coarse grained sparse gather/scatter kernel achieving peak performance on Ampere class tensor cores.
|
||||
- 32x and 16x tile sizes are added to CUTLASS 2.x to improve the performance of narrow-tall and wide-short matrices.
|
||||
+ [Ampere FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm80.cu) and [NT](./test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu#L227-L301), [Ampere INT8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu#L392-L1342), [Ampere INT4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu#L372-L934).
|
||||
+ [Turing FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm75.cu#L55-L394), [Turing INT8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu#L166-L537), [Turing INT4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu#L310-L564).
|
||||
- Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cute/03_tensor.md), [MMA atoms](./media/docs/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](./examples/cute/tutorial).
|
||||
- Extensions to CuTe to support [L2 prefetching](./include/cute/algorithm/prefetch.hpp) and [TMA store+reductions](./include/cute/arch/copy_sm90_tma.hpp#L1337).
|
||||
- Remove C++11 requirement on a few CUTLASS 2.x API header files. All CUTLASS files now require C++17.
|
||||
- Fixes to greatly reduce build warnings.
|
||||
- Updates and bugfixes from the community (thanks!)
|
||||
|
||||
## [3.4.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.1) (2024-02-14)
|
||||
|
||||
- Statically available [CUTLASS Version macros](./include/cutlass/version.h) that allow for handling API changes between CUTLASS releases on the users' side.
|
||||
- Improvements for Hopper [Group-GEMMs](./examples/57_hopper_grouped_gemm) and [Pointer-Array Batched GEMMs](./examples/56_hopper_ptr_array_batched_gemm).
|
||||
- Updates and bugfixes from the community (thanks!).
|
||||
|
||||
## [3.4.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.0) (2024-01-12)
|
||||
* Expanded [Mixed-input Hopper GEMMs](./examples/55_hopper_mixed_dtype_gemm) support covering {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors.
|
||||
* Performance improvements to [Mixed-input Hopper GEMMs](./examples/55_hopper_mixed_dtype_gemm)
|
||||
* Beta release of [Pointer-Array Batched GEMMs](./examples/56_hopper_ptr_array_batched_gemm) now available on Hopper GPUs utilizing TMA and WGMMA (requires CUDA 12.3 or above).
|
||||
* Beta release of [Group-GEMM](./examples/57_hopper_grouped_gemm) utilizing TMA and WGMMA (requires CUDA 12.3 or above).
|
||||
* [Ampere Sparse GEMM](./examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu) supports Epilogue Visitor Tree (EVT) now.
|
||||
* NamedBarriers usability improvement and list of [ReservedNamedBarriers](./include/cutlass/arch/barrier.h) has been officially released.
|
||||
* Improved [CuTe documentation](./media/docs/cute/) including improved clarity and depth of [Quickstart](./media/docs/cute/00_quickstart.md), [CuTe Layout](./media/docs/cute/01_layout.md), and [CuTe Layout Algebra](./media/docs/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](./test/unit/cute/core/) also improved.
|
||||
|
||||
## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3.0) (2023-10-31)
|
||||
* [Mixed-input Hopper GEMMs](./examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types.
|
||||
* [Mixed-input Ampere GEMMs](https://github.com/NVIDIA/cutlass/pull/1084) with support for canonical layouts (TN). The implementation supports upcast on operandB {fp16, bf16} x {s8, u8}, and upcast on operandA {s8, u8} x {fp16, bf16}.
|
||||
* [Copy Async based Hopper GEMMs](./test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu) - which support lower than 16B aligned input tensors.
|
||||
* Kernel schedules and Builder support for mixed precision and Copy Async GEMMs with < 16B aligned input tensors.
|
||||
* Profiler support for lower-aligned Hopper GEMMs.
|
||||
* Performance Improvements to [Scatter-Gather Hopper Example](./examples/52_hopper_gather_scatter_fusion).
|
||||
* Sub-Byte type fixes and improvements.
|
||||
* EVT Support for RELU with Aux bitmap tensor store (used in dRELU). See [SM90 EVT fusions](./include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp) for details.
|
||||
* Fusion support for backprop fusions including drelu, dgelu, and dbias.
|
||||
* Support for void-C kernels and SM80 mixed-input GEMMs in the CUTLASS Python interface
|
||||
|
||||
## [3.2.2](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.2) (2023-10-25)
|
||||
* Minor patch for issue/1138
|
||||
|
||||
## [3.2.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1) (2023-09-22)
|
||||
* Python support SM90 Epilogue Visitor Tree (EVT) on top of the C++ support released in 3.2.0.
|
||||
* SM80 EVT support in C++ and Python.
|
||||
* Other SM90 epilogue improvements.
|
||||
* Splitting CUTLASS library into smaller units based on operation, arch and datatypes. See [1105](https://github.com/NVIDIA/cutlass/discussions/1105) for details.
|
||||
* Making `tools/library/scripts` packageable - `tools/library/scripts` is now moving to `python/cutlass_library`. See the Python [README](./python/README.md) for details.
|
||||
* SM90 TF32 kernel improvements for all layouts.
|
||||
* SM90 rasterization direction support in the CUTLASS profiler.
|
||||
* Improvement for CUTLASS profiler build times.
|
||||
* Remove Python-C++ bindings.
|
||||
|
||||
## [3.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.0) (2023-08-03)
|
||||
|
||||
* New warp-specialized persistent FP8 GEMM kernel [kernel schedules](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) and [mainloops](./include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. An example showcasing [Hopper warp-specialized FP8 GEMMs](./examples/54_hopper_fp8_warp_specialized_gemm). FP8 GEMMs come with a fast accumulation mode. When enabled, problem execution might be faster but at the cost of lower accuracy because intermediate results will not periodically be promoted to a higher precision.
|
||||
* New [Epilogue Visitor Tree (EVT)](./examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu) support for Hopper TMA epilogues. EVTs allows for user-defined customized epilogue fusion patterns without having to write a new epilogue.
|
||||
* [Stream-K](./include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp) feature for Hopper. Note that this is only a functional implementation of stream-K, and should not be used for performance comparison. Optimizations are expected in a future release.
|
||||
* Improved CTA rasterization and support for CTA swizzling for Hopper kernels using the [Tile Scheduler](./include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp).
|
||||
* Improved performance for [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
|
||||
* [Hopper GEMM+Permute](./examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu), an example of fusing tensor reordering (permutation) with GEMM mainloop or epilogue.
|
||||
* New CUTLASS 2D Convolution Python interface. New [example](./examples/python/03_basic_conv2d.ipynb) here.
|
||||
* Support for Windows (MSVC) builds. Tested with Visual Studio 2019 v16.11.27 on Windows 10.0.
|
||||
* Optimal performance using [**CUDA 12.2u1**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates and bugfixes from the community (thanks!)
|
||||
|
||||
## [3.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.1.0) (2023-04-14)
|
||||
* New CUTLASS Python interface that aims to provide an ease-of-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](./python/README.md) and new [examples](./examples/python).
|
||||
* New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) using TMA for Hopper.
|
||||
* Support for [fused epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu), such Bias, ReLU and GELU, using the new efficient epilogues.
|
||||
* New [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
|
||||
* New [*warp-specialized persistent cooperative*](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel design that allows for larger tile sizes and improves performance on Hopper.
|
||||
* An [example](./examples/51_hopper_gett) showcasing GEMM-Like Tensor-Tensor Contraction (GETT) capability on Hopper.
|
||||
* Epilogue builders. Similar to mainloop builders (see [example 49](./examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu)), epilogue builders aim to generate the best-possible epilogue while exposing incremental opt-ins for greater customization.
|
||||
* Profiler support for overriding kernel and epilogue builder auto schedules for 3.x API kernels, allowing specific policies to be run in the CUTLASS profiler.
|
||||
* Performance optimizations for the [*warp-specialized persistent ping-pong*](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) kernel.
|
||||
* Changes to the [GEMM API 3.x](./media/docs/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs.
|
||||
* [FMHA Backward Pass](./examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu) from Meta xFormers.
|
||||
* [Streamk GEMM with Broadcast](./examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu) enables epilogue broadcast with StreamK GEMM.
|
||||
* [Batched B2B GEMM](./examples/13_two_tensor_op_fusion) now can run multiple Back-to-Back GEMM with the same problem size in parallel.
|
||||
* [Batched Strided GEMV](test/unit/gemm/device/gemv.cu) support both row major and column major input matrix.
|
||||
* [Permute + GEMM fusion](./examples/39_gemm_permute) can fuse Permute with following GEMM now. Before, we only support fusing GEMM with Permute in the epilogue.
|
||||
* [Row Broadcast](./include/cutlass/epilogue/threadblock/predicated_tile_iterator_row_broadcast.h) can be fused in the epilogue.
|
||||
* The GitHub branch is renamed from `master` to `main` in this release.
|
||||
* Optimal performance using [**CUDA 12.1**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates and bugfixes from the community (thanks!)
|
||||
|
||||
## [3.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.0.0) (2023-01-23)
|
||||
* [CuTe](/media/docs/cute/00_quickstart.md), a [new core library and backend](/include/cute) for CUTLASS 3.0 that defines a single Layout vocabulary type and an associated algebra of layouts for a much more expressive and composable abstraction for tensors, sets of parallel agents, and operations by said agents on tensors.
|
||||
* [A new conceptual operation hierarchy](media/docs/cutlass_3x_design.md) that replaces the architecture-centric hierarchy of CUTLASS 2.x and [documentation for CUTLASS 3.0's GEMM API changes](/media/docs/gemm_api_3x.md).
|
||||
* Strict API backwards compatibility that exposes both 2.x and 3.x API kernels through the same [`device::GemmUniversalAdapter`](include/cutlass/gemm/device/gemm_universal_adapter.h) and [`kernel::GemmUniversal`](include/cutlass/gemm/kernel/gemm_universal.hpp) types, allowing users to include both APIs in the same translation units. More information can be found in the [3.x backwards compatibility section](media/docs/cutlass_3x_backwards_compatibility.md).
|
||||
* Updates to [Functionality](media/docs/functionality.md) which directs users on which kernels are supported via CUTLASS-2 and CUTLASS-3.
|
||||
* Updates to [Compatibility](/README.md#compatibility) Section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures and [Target Architecture](/README.md#Target-Architecture).
|
||||
* New warp-specialized GEMM [kernel schedules](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) and [mainloops](include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters.
|
||||
* [CuTe](./media/docs/cute/00_quickstart.md), a [new core library and backend](./include/cute) for CUTLASS 3.0 that defines a single Layout vocabulary type and an associated algebra of layouts for a much more expressive and composable abstraction for tensors, sets of parallel agents, and operations by said agents on tensors.
|
||||
* [A new conceptual operation hierarchy](./media/docs/cutlass_3x_design.md) that replaces the architecture-centric hierarchy of CUTLASS 2.x and [documentation for CUTLASS 3.0's GEMM API changes](./media/docs/gemm_api_3x.md).
|
||||
* Strict API backwards compatibility that exposes both 2.x and 3.x API kernels through the same [`device::GemmUniversalAdapter`](./include/cutlass/gemm/device/gemm_universal_adapter.h) and [`kernel::GemmUniversal`](./include/cutlass/gemm/kernel/gemm_universal.hpp) types, allowing users to include both APIs in the same translation units. More information can be found in the [3.x backwards compatibility section](./media/docs/cutlass_3x_backwards_compatibility.md).
|
||||
* Updates to [Functionality](./media/docs/functionality.md) which directs users on which kernels are supported via CUTLASS-2 and CUTLASS-3.
|
||||
* Updates to [Compatibility](./README.md#compatibility) Section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures and [Target Architecture](./README.md#Target-Architecture).
|
||||
* New warp-specialized GEMM [kernel schedules](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) and [mainloops](./include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters.
|
||||
* Extensions to CUTLASS profiler to support threadblock cluster shapes in library and profiler tile configurations.
|
||||
* [CUTLASS library integration](/tools/library/src/gemm_operation_3x.hpp) for 3.x API kernels built through the new `CollectiveBuilder` API, enabling CUTLASS profiler.
|
||||
* Support for [Hopper GEMMs](examples/48_hopper_warp_specialized_gemm) through the new 3.0 API with CuTe-based exposure of the Hopper [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor) and [WGMMA Tensor Core](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) features.
|
||||
* Set of examples that demonstrate the usage of the new 3.0 API to easily build GEMM kernels targeting Hopper: examples [48](examples/48_hopper_warp_specialized_gemm), [49](examples/49_hopper_gemm_schedules_with_collective_builder), and [50](examples/50_hopper_gemm_with_epilogue_swizzle).
|
||||
* [CUTLASS library integration](./tools/library/src/gemm_operation_3x.hpp) for 3.x API kernels built through the new `CollectiveBuilder` API, enabling CUTLASS profiler.
|
||||
* Support for [Hopper GEMMs](./examples/48_hopper_warp_specialized_gemm) through the new 3.0 API with CuTe-based exposure of the Hopper [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor) and [WGMMA Tensor Core](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) features.
|
||||
* Set of examples that demonstrate the usage of the new 3.0 API to easily build GEMM kernels targeting Hopper: examples [48](./examples/48_hopper_warp_specialized_gemm), [49](./examples/49_hopper_gemm_schedules_with_collective_builder), and [50](./examples/50_hopper_gemm_with_epilogue_swizzle).
|
||||
|
||||
## [2.11.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.11.0) (2022-11-19)
|
||||
* [Stream-K](/examples/47_ampere_gemm_universal_streamk), which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one.
|
||||
* [Fused multi-head attention Kernel](/examples/41_fused_multi_head_attention). It has two variants: one uses batched GEMM for the fixed sequence length, and the other one uses group GEMM for the variable sequence length. Both versions just need one kernel.
|
||||
* [Dual GEMM](/examples/45_dual_gemm), which can fuse A x B and A x C into one kernel. Two GEMMs has no producer-consumer dependency.
|
||||
* Hopper improves [double precision matrix multiplication](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported since CUDA 11.8.
|
||||
* [BLAS3](/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu) functions with Hoppers new double precision matrix multiplication instructions.
|
||||
* [ELL Block Sparse GEMM](/examples/43_ell_block_sparse_gemm), which uses an [ELL matrix](https://developer.nvidia.com/blog/accelerating-matrix-multiplication-with-block-sparse-format-and-nvidia-tensor-cores/) to describe the sparsity of A matrix. B and output matrices are still dense. The block size can be arbitary.
|
||||
* Optimized [Group Conv](/examples/42_ampere_tensorop_group_conv) for SingleGroup mode, which requires that the output channel per group is a multiple of Threadblock tile N.
|
||||
* [Optimized DepthWise Conv](/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu). Two new modes are added
|
||||
* [kOptimized](/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - use direct conv to compute instead of implicit GEMM.
|
||||
* [Stream-K](./examples/47_ampere_gemm_universal_streamk), which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one.
|
||||
* [Fused multi-head attention Kernel](./examples/41_fused_multi_head_attention). It has two variants: one uses batched GEMM for the fixed sequence length, and the other one uses group GEMM for the variable sequence length. Both versions just need one kernel.
|
||||
* [Dual GEMM](./examples/45_dual_gemm), which can fuse A x B and A x C into one kernel. Two GEMMs has no producer-consumer dependency.
|
||||
* Hopper improves [double precision matrix multiplication](./test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported since CUDA 11.8.
|
||||
* [BLAS3](./test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu) functions with Hoppers new double precision matrix multiplication instructions.
|
||||
* [ELL Block Sparse GEMM](./examples/43_ell_block_sparse_gemm), which uses an [ELL matrix](https://developer.nvidia.com/blog/accelerating-matrix-multiplication-with-block-sparse-format-and-nvidia-tensor-cores/) to describe the sparsity of A matrix. B and output matrices are still dense. The block size can be arbitary.
|
||||
* Optimized [Group Conv](./examples/42_ampere_tensorop_group_conv) for SingleGroup mode, which requires that the output channel per group is a multiple of Threadblock tile N.
|
||||
* [Optimized DepthWise Conv](./examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu). Two new modes are added
|
||||
* [kOptimized](./test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - use direct conv to compute instead of implicit GEMM.
|
||||
* The restrictions are: 1) input ,output channel and group number should be multiple of (128 / sizeof(input element)). 2) The input filter size should be the same as the template parameter configuration.
|
||||
* [kFixedStrideDilation](/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - which puts stride and dilation into templates to further improve the performance. In this mode, kernel persistents some inputs into register to squeeze more performance, so large filter/stride/dilation is not recommanded.
|
||||
* The restrictions are: 1) input, output channel and group number should be multiple of (128 / sizeof(input element)). 2) input filter size, stride, dilation should same as the template parameter configuration.
|
||||
* [Scripts](/examples/44_multi_gemm_ir_and_codegen) to fuse multiple back-to-back GEMM. Its implementation was discussed in a GTC'22 Spring [talk](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41606/).
|
||||
* [FP8 data type definition](/include/cutlass/float8.h) and [conversion routines](/include/cutlass/numeric_conversion.h#L1274-2115).
|
||||
* [kFixedStrideDilation](./test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - which puts stride and dilation into templates to further improve the performance. In this mode, kernel persistents some inputs into register to squeeze more performance, so large filter/stride/dilation is not recommanded.
|
||||
* The restrictions are: 1) input, output channel and group number should be multiple of (128 / sizeof(input element)). 2) input filter size, stride, dilation should same as the template parameter configuration.
|
||||
* [Scripts](./examples/44_multi_gemm_ir_and_codegen) to fuse multiple back-to-back GEMM. Its implementation was discussed in a GTC'22 Spring [talk](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41606/).
|
||||
* [FP8 data type definition](./include/cutlass/float8.h) and [conversion routines](./include/cutlass/numeric_conversion.h#L1274-2115).
|
||||
* Updates and bugfixes from the community (thanks!). Big shout out to Meta's [xFormers](https://github.com/facebookresearch/xformers).
|
||||
|
||||
* **Deprecation announcement:** CUTLASS plans to deprecate the following:
|
||||
@ -36,54 +247,54 @@
|
||||
* CUDA 10.2
|
||||
|
||||
## [2.10.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.10.0) (2022-08-23)
|
||||
* [CUTLASS Python](/examples/40_cutlass_py) now supports GEMM, CONV, Group GEMM for different data types as well as different epilogue flavours.
|
||||
* Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel. Threadblock scheduling part is improved. Some computation can be moved to the host side if applicable. [Grouped Syr2k](examples/38_syr2k_grouped/syr2k_grouped.cu) kernels are added, too.
|
||||
* Optimizations for [GEMM+Softmax](examples/35_gemm_softmax). All the reduction computation is fused into the previous GEMM. More template arguments are provided to fine tune the performance.
|
||||
* [Grouped GEMM for Multihead Attention](examples/41_multi_head_attention). This general group gemm based MHA does not require the sequence length of all GEMMs to be the same which makes it most useful for natural language processing.
|
||||
* [GEMM + Layer norm fusion for Ampere](examples/37_gemm_layernorm_gemm_fusion/) splits the layernorm into two parts and both of them can be fused into the GEMMs before and after separately. In addition to use square sum to compute variance of layernorm, [Shift-K](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data) is provided if square sum raise numerical issues.
|
||||
* [GEMM Epilogue Permutation Fusion](examples/39_gemm_permute) can apply user provided permutation layout mapping in the GEMM epilogue.
|
||||
* [Grouped convolution targeting implicit GEMM](test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized. The restrictions are: 1) input and output channel number should be multiple of group number. 2) split-K is not supported. The implementation has 2 modes:
|
||||
* [CUTLASS Python](./examples/40_cutlass_py) now supports GEMM, CONV, Group GEMM for different data types as well as different epilogue flavours.
|
||||
* Optimizations for CUTLASS's [Grouped GEMM](./examples/24_gemm_grouped/gemm_grouped.cu) kernel. Threadblock scheduling part is improved. Some computation can be moved to the host side if applicable. [Grouped Syr2k](./examples/38_syr2k_grouped/syr2k_grouped.cu) kernels are added, too.
|
||||
* Optimizations for [GEMM+Softmax](./examples/35_gemm_softmax). All the reduction computation is fused into the previous GEMM. More template arguments are provided to fine tune the performance.
|
||||
* [Grouped GEMM for Multihead Attention](./examples/41_multi_head_attention). This general group gemm based MHA does not require the sequence length of all GEMMs to be the same which makes it most useful for natural language processing.
|
||||
* [GEMM + Layer norm fusion for Ampere](./examples/37_gemm_layernorm_gemm_fusion/) splits the layernorm into two parts and both of them can be fused into the GEMMs before and after separately. In addition to use square sum to compute variance of layernorm, [Shift-K](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data) is provided if square sum raise numerical issues.
|
||||
* [GEMM Epilogue Permutation Fusion](./examples/39_gemm_permute) can apply user provided permutation layout mapping in the GEMM epilogue.
|
||||
* [Grouped convolution targeting implicit GEMM](test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized. The restrictions are: 1) input and output channel number should be multiple of group number. 2) split-K is not supported. The implementation has 2 modes:
|
||||
* kSingleGroup: output channel per group is multiple of Threadblock tile N.
|
||||
* kMultipleGroup: Threadblock tile N is multiple of output channel per group.
|
||||
* [Depthwise separable convolution](test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) introduces the first depthwise convolution which is also Analytical for now. The restrictions are: 1) SIMT only 2) No split-K 3) input channel equals to output channel equals to group number.
|
||||
* Standalone [Layernorm](/tools/util/include/cutlass/util/device_layernorm.h) and [Pooling](/tools/util/include/cutlass/util/device_nhwc_pooling.h) kernels.
|
||||
* [Back-to-back GEMM/CONV](examples/13_two_tensor_op_fusion) relaxes the requirement that the first GEMM K dimension needs to be the multiple of Threadblock Tile K dimension.
|
||||
* [Depthwise separable convolution](test/unit/conv/device/depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) introduces the first depthwise convolution which is also Analytical for now. The restrictions are: 1) SIMT only 2) No split-K 3) input channel equals to output channel equals to group number.
|
||||
* Standalone [Layernorm](./tools/util/include/cutlass/util/device_layernorm.h) and [Pooling](./tools/util/include/cutlass/util/device_nhwc_pooling.h) kernels.
|
||||
* [Back-to-back GEMM/CONV](./examples/13_two_tensor_op_fusion) relaxes the requirement that the first GEMM K dimension needs to be the multiple of Threadblock Tile K dimension.
|
||||
* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates and bugfixes from the community (thanks!)
|
||||
|
||||
## [2.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.9.0) (2022-04-21)
|
||||
|
||||
* [First layer Convolution kernels](/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment
|
||||
* [Few channels](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h) specialization for reduced alignment capabilities
|
||||
* [Fixed channels](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h) further specialized when channel count perfectly matches the access vector size
|
||||
* [Unit tests](/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu)
|
||||
* [Python-based instance emitter](/tools/library/scripts/generator.py) in the CUTLASS Library and support in the Profiler
|
||||
* [First layer Convolution kernels](./test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment
|
||||
* [Few channels](./include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h) specialization for reduced alignment capabilities
|
||||
* [Fixed channels](./include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h) further specialized when channel count perfectly matches the access vector size
|
||||
* [Unit tests](./test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu)
|
||||
* [Python-based instance emitter](./python/cutlass_library/generator.py) in the CUTLASS Library and support in the Profiler
|
||||
* [BLAS3](https://docs.nvidia.com/cuda/cublas/index.html#cublas-level-3-function-reference) operators accelerated by Tensor Cores
|
||||
* Supported types: f32, cf32, f64, cf64, tf32x3, complex tf32x3
|
||||
* [HERK](/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu) with [emitter](/tools/library/scripts/rank_k_operation.py)
|
||||
* [SYRK](/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu) with [emitter](/tools/library/scripts/rank_k_operation.py)
|
||||
* [SYMM](/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu) with [emitter](/tools/library/scripts/symm_operation.py)
|
||||
* [TRMM](/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu) with [emitter](/tools/library/scripts/trmm_operation.py)
|
||||
* [Unit tests](/test/unit/gemm/device/testbed_rank_k_universal.h)
|
||||
* [CUTLASS Python](/examples/40_cutlass_py) demonstrating JIT compilation of CUTLASS kernels and a Python-based runtime using [CUDA Python](https://developer.nvidia.com/cuda-python)
|
||||
* [Python-based runtime](/tools/library/scripts/rt.py) interoperable with existing emitters
|
||||
* [GEMM + Softmax example](/examples/35_gemm_softmax)
|
||||
* [Gather and Scatter Fusion with GEMM](/examples/36_gather_scatter_fusion) can gather inputs and scatters outputs based on indices vectors in the same GEMM kernel.
|
||||
* [HERK](./test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu) with [emitter](./python/cutlass_library/rank_k_operation.py)
|
||||
* [SYRK](./test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu) with [emitter](./python/cutlass_library/rank_k_operation.py)
|
||||
* [SYMM](./test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu) with [emitter](./python/cutlass_library/symm_operation.py)
|
||||
* [TRMM](./test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu) with [emitter](./python/cutlass_library/trmm_operation.py)
|
||||
* [Unit tests](./test/unit/gemm/device/testbed_rank_k_universal.h)
|
||||
* [CUTLASS Python](./examples/40_cutlass_py) demonstrating JIT compilation of CUTLASS kernels and a Python-based runtime using [CUDA Python](https://developer.nvidia.com/cuda-python)
|
||||
* [Python-based runtime](./tools/library/scripts/rt.py) interoperable with existing emitters
|
||||
* [GEMM + Softmax example](./examples/35_gemm_softmax)
|
||||
* [Gather and Scatter Fusion with GEMM](./examples/36_gather_scatter_fusion) can gather inputs and scatters outputs based on indices vectors in the same GEMM kernel.
|
||||
* It can select random rows in a row major matrix.
|
||||
* It can select random columns in a column major matrix.
|
||||
* [Back-to-back GEMM/CONV](examples/13_two_tensor_op_fusion) fully supports buffering the first GEMM/CONV results in the shared memory for the latter one to use. It can eliminate register spill when the tile size is big. Additionally, bias vector add is supported in the first GEMM/CONV.
|
||||
* [Back-to-back GEMM/CONV](./examples/13_two_tensor_op_fusion) fully supports buffering the first GEMM/CONV results in the shared memory for the latter one to use. It can eliminate register spill when the tile size is big. Additionally, bias vector add is supported in the first GEMM/CONV.
|
||||
* Supported kernels: GEMM and CONV.
|
||||
* Supported types: fp16 and int8.
|
||||
* Supported architectures: Turing and Ampere.
|
||||
* [Transposed Convolution](/examples/34_transposed_conv2d) (a.k.a Deconvolution) support which reuses Dgrad implementation.
|
||||
* [Utility functions](/tools/util/include/cutlass/util) that can pad NHWC and convert between NCHW and NHWC.
|
||||
* [Transposed Convolution](./examples/34_transposed_conv2d) (a.k.a Deconvolution) support which reuses Dgrad implementation.
|
||||
* [Utility functions](./tools/util/include/cutlass/util) that can pad NHWC and convert between NCHW and NHWC.
|
||||
* [Small alignment implicit gemm](https://github.com/NVIDIA/cutlass/issues/242) support for Fprop/Dgrad/Wgrad so that padding is no longer mandated to use tensor cores in these kernels.
|
||||
* Epilogue enhancement:
|
||||
* Eliminate bank conflicts in int8 tensor core kernels.
|
||||
* Half2 usage if epilogue compute type is fp16.
|
||||
* More activation functions: Silu, Hardswish, Leaky Relu.
|
||||
* New elementwise fusion pattern for [residual block](/include/cutlass/epilogue/thread/linear_combination_residual_block.h).
|
||||
* [Group GEMM](/examples/24_gemm_grouped) thread block number calculation fix which helps to launch the intended number of threadblocks to fully occupy the GPUs.
|
||||
* New elementwise fusion pattern for [residual block](./include/cutlass/epilogue/thread/linear_combination_residual_block.h).
|
||||
* [Group GEMM](./examples/24_gemm_grouped) thread block number calculation fix which helps to launch the intended number of threadblocks to fully occupy the GPUs.
|
||||
* [Parallel GEMM splitk](https://github.com/NVIDIA/cutlass/pull/277) support in the CUTLASS profiler.
|
||||
* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates and bugfixes from the community (thanks!)
|
||||
@ -93,17 +304,17 @@
|
||||
|
||||
* **TF32x3:** emulated single-precision using Tensor Cores
|
||||
* 45+ TFLOPs on NVIDIA A100
|
||||
* [GEMM SDK example](/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu) (real)
|
||||
* [COMPLEX GEMM SDK example](/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm.cu) (complex)
|
||||
* [Implicit GEMM Convolution SDK example](/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu)
|
||||
* [GEMM SDK example](./examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu) (real)
|
||||
* [COMPLEX GEMM SDK example](./examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu) (complex)
|
||||
* [Implicit GEMM Convolution SDK example](./examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu)
|
||||
* **Mainloop fusion for Convolution:** convolution with fused per-channel scale-bias-relu
|
||||
* [Conv Fprop SDK example](/examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu)
|
||||
* [Conv WGrad SDK example](/examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu)
|
||||
* [cutlass::conv::device::ImplicitGemmConvolutionFusion](/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h)
|
||||
* [Conv Fprop SDK example](./examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu)
|
||||
* [Conv WGrad SDK example](./examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu)
|
||||
* [cutlass::conv::device::ImplicitGemmConvolutionFusion](./include/cutlass/conv/device/implicit_gemm_convolution_fusion.h)
|
||||
* **Grouped GEMM:** similar to batched GEMM with distinct problem size per group
|
||||
* [SDK example](/examples/24_gemm_grouped) with performance comparison with Batched Strided GEMM
|
||||
* [cutlass::gemm::device::GemmGrouped](/include/cutlass/gemm/device/gemm_grouped.h)
|
||||
* [Implicit GEMM Convolution fusion](/examples/13_two_tensor_op_fusion/) supports staging 1st convolution's output accumulator in the shared memory on Turing. This allows more flexible warp tile sizes and less regsiter pressue.
|
||||
* [SDK example](./examples/24_gemm_grouped) with performance comparison with Batched Strided GEMM
|
||||
* [cutlass::gemm::device::GemmGrouped](./include/cutlass/gemm/device/gemm_grouped.h)
|
||||
* [Implicit GEMM Convolution fusion](./examples/13_two_tensor_op_fusion/) supports staging 1st convolution's output accumulator in the shared memory on Turing. This allows more flexible warp tile sizes and less regsiter pressue.
|
||||
* Optimal performance using [**CUDA 11.5**](https://developer.nvidia.com/cuda-downloads)
|
||||
* Updates from the community (thanks!)
|
||||
|
||||
@ -113,11 +324,11 @@
|
||||
* CUDA 10.2
|
||||
|
||||
## [2.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.7.0) (2021-09-24)
|
||||
* Mainloop fusion for GEMM: [summation over A or B](/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu)
|
||||
* [Strided DGRAD (optimized iterators)](/include/cutlass/conv/kernel/default_conv2d_dgrad.h)
|
||||
* [Half-precision GELU_taylor activation functions](/include/cutlass/epilogue/thread/activation.h#L196)
|
||||
* Mainloop fusion for GEMM: [summation over A or B](./examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu)
|
||||
* [Strided DGRAD (optimized iterators)](./include/cutlass/conv/kernel/default_conv2d_dgrad.h)
|
||||
* [Half-precision GELU_taylor activation functions](./include/cutlass/epilogue/thread/activation.h#L196)
|
||||
* Use these when accumulation and epilogue compute types are all `cutlass::half_t`
|
||||
* Tuning and bug fixes to [fused GEMM + GEMM example](/examples/13_two_tensor_op_fusion/)
|
||||
* Tuning and bug fixes to [fused GEMM + GEMM example](./examples/13_two_tensor_op_fusion/)
|
||||
* Support for smaller than 128b aligned Convolutions: [see examples](test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu#L272)
|
||||
* Caching of results to accelerate Convolution [unit tests](test/unit/conv/device/cache_testbed_output.h)
|
||||
* Can be enabled or disabled by running `cmake .. -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=OFF`
|
||||
@ -132,27 +343,27 @@
|
||||
|
||||
## [2.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.6.0) (2021-07-22)
|
||||
* Optimal performance when compiled with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit)
|
||||
* Adopt the new L2 prefetch feature in [cp.async](/include/cutlass/arch/memory.h) and [global load](/include/cutlass/arch/memory_sm80.h)
|
||||
* Adopt the new L2 prefetch feature in [cp.async](./include/cutlass/arch/memory.h) and [global load](./include/cutlass/arch/memory_sm80.h)
|
||||
* Fused operators with GEMM and Convolution
|
||||
* [Fused broadcast in epilogue](test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu)
|
||||
* [Fused partial reduction in epilogue](/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu)
|
||||
* [Fused partial reduction in epilogue](./test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu)
|
||||
* 64b tensor strides and leading dimensions support for GEMMs
|
||||
* Affine rank=2 matrix layouts
|
||||
* Row stride and column stride for matrices using [cutlass::layout::AffineRank2](/include/cutlass/layout/matrix.h)
|
||||
* Support [FP64 tensor core](/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu) and SIMT GEMM.
|
||||
* [Batched GEMV](/test/unit/gemm/device/gemv.cu) preview implementation
|
||||
* Affine rank=2 matrix layouts
|
||||
* Row stride and column stride for matrices using [cutlass::layout::AffineRank2](./include/cutlass/layout/matrix.h)
|
||||
* Support [FP64 tensor core](./examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu) and SIMT GEMM.
|
||||
* [Batched GEMV](./test/unit/gemm/device/gemv.cu) preview implementation
|
||||
* [New strided Dgrad](test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) implementation
|
||||
* Accelerates over previous implementation by cutting down redundant math by 4x
|
||||
* Support using new `Dy` and `w` analytic iterators and existing `cutlass::conv::device::ImplicitGemmConvolution` interface
|
||||
* Quaternion-valued GEMM and Convolution in single- and double-precision (targeting CUDA Cores)
|
||||
* Updates to [quaternion.h](/include/cutlass/quaternion.h) and [functional.h](/include/cutlass/functional.h)
|
||||
* SDK Example for [GEMM](/examples/21_quaternion_gemm/quaternion_gemm.cu) and [Convolution](/examples/22_quaternion_gemm/quaternion_conv.cu)
|
||||
* [Unit tests for GEMM](/test/unit/gemm/device/simt_qgemm_nn_sm50.cu) and [Convolution](/test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu)
|
||||
* Updates to [quaternion.h](./include/cutlass/quaternion.h) and [functional.h](./include/cutlass/functional.h)
|
||||
* SDK Example for [GEMM](./examples/21_quaternion_gemm/quaternion_gemm.cu) and [Convolution](./examples/22_quaternion_conv/quaternion_conv.cu)
|
||||
* [Unit tests for GEMM](./test/unit/gemm/device/simt_qgemm_nn_sm50.cu) and [Convolution](./test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu)
|
||||
* Many improvements to the epilogue.
|
||||
* Provide an [option](/include/cutlass/epilogue/threadblock/epilogue.h) to not fully unroll the epilogue to reduce the code size and improve the performance when using complicated elementwise operations
|
||||
* Provide an [option](./include/cutlass/epilogue/threadblock/epilogue.h) to not fully unroll the epilogue to reduce the code size and improve the performance when using complicated elementwise operations
|
||||
* Performance improvement for FP16 tensor core kernels
|
||||
* Bug fixes
|
||||
* Enhanced Clang support and the combination of Clang 13 and CUDA 11.4 can build and run kernels from Pascal and Ampere.
|
||||
* Enhanced Clang support and the combination of Clang 13 and CUDA 11.4 can build and run kernels from Pascal and Ampere.
|
||||
* Updated minimum CUDA Toolkit requirement to 10.2
|
||||
* [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit) recommended
|
||||
* Corrections and bug fixes reported by the CUTLASS community
|
||||
@ -161,17 +372,17 @@
|
||||
## [2.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.5.0) (2021-02-26)
|
||||
* Tensor reductions
|
||||
* _m_-to-_n_ reductions of tensors with affine layout
|
||||
* [Specializations](/test/unit/reduction/device/tensor_reduce_contiguous.cu) for reductions including contiguous dimension
|
||||
* [Specializations](/test/unit/reduction/device/tensor_reduce_strided.cu) for reductions excluding contiguous dimension
|
||||
* [Specializations](./test/unit/reduction/device/tensor_reduce_contiguous.cu) for reductions including contiguous dimension
|
||||
* [Specializations](./test/unit/reduction/device/tensor_reduce_strided.cu) for reductions excluding contiguous dimension
|
||||
* Custom reduction functors such as `cutlass::logical_and`
|
||||
* Large tensor support, up to 2^63 elements (however, each dimension is limited to an extent of 2^31)
|
||||
* Optimizations for 3-D convolution
|
||||
* [Optimized tile iterators](include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h) using precomputed delta table for 3-D convolution
|
||||
* [Optimized tile iterators](./include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h) using precomputed delta table for 3-D convolution
|
||||
* Full coverage of [forward](test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu) and [backwards](test/unit/conv/device/conv3d_dgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu) passes for 3D convolution
|
||||
* [Fused Convolution+Convolution example](/examples/13_two_tensor_op_fusion/README.md)
|
||||
* [Fused Convolution+Convolution example](./examples/13_two_tensor_op_fusion/README.md)
|
||||
* Corrections and bug fixes reported by the CUTLASS community
|
||||
* Thank you for filing these issues!
|
||||
|
||||
|
||||
|
||||
## [2.4.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.4.0) (2020-11-19)
|
||||
* Implicit GEMM convolution kernels supporting CUDA and Tensor Cores on NVIDIA GPUs
|
||||
@ -179,11 +390,11 @@
|
||||
* Data type: FP32, complex<FP32>, Tensor Float 32 (TF32), BFloat16 (BF16), Float16, Int4, Int8, Int32
|
||||
* Spatial dimensions: 1-D, 2-D, and 3-D
|
||||
* Layout: NHWC, NCxHWx
|
||||
* Implicit GEMM convolution components:
|
||||
* Implicit GEMM convolution components:
|
||||
* Global memory iterators supporting Fprop, Dgrad, and Wgrad
|
||||
* `MmaMultistage` for implicit GEMM convolution for NVIDIA Ampere architecture
|
||||
* `MmaPipeline` for implicit GEMM convolution for NVIDIA Volta and Turing architectures
|
||||
* [Documentation](/media/docs/implicit_gemm_convolution.md) describing Implicit GEMM Convolution algorithm and implementation
|
||||
* [Documentation](./media/docs/implicit_gemm_convolution.md) describing Implicit GEMM Convolution algorithm and implementation
|
||||
|
||||
## [2.3.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.3.0) (2020-09-23)
|
||||
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
|
||||
@ -191,21 +402,21 @@
|
||||
* Direct access to Sparse Tensor Cores and maximum performance via [`mma.sp.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends)
|
||||
* Fast SGEMM targeting GeForce RTX 30-series CUDA Cores
|
||||
* Minor Features:
|
||||
* [Activation functions](/include/cutlass/epilogue/thread/activation.h) such as [GeLU](/include/cutlass/epilogue/thread/linear_combination_gelu.h) and [Sigmoid](/include/cutlass/epilogue/thread/linear_combination_sigmoid.h)
|
||||
* Small [matrix](/include/cutlass/matrix.h) and [quaternion](/include/cutlass/quaternion.h) template classes in device code
|
||||
* [Floating-point constants](/include/cutlass/constants.h)
|
||||
* [Activation functions](./include/cutlass/epilogue/thread/activation.h) such as [GeLU](./include/cutlass/epilogue/thread/linear_combination_gelu.h) and [Sigmoid](./include/cutlass/epilogue/thread/linear_combination_sigmoid.h)
|
||||
* Small [matrix](./include/cutlass/matrix.h) and [quaternion](./include/cutlass/quaternion.h) template classes in device code
|
||||
* [Floating-point constants](./include/cutlass/constants.h)
|
||||
* NVIDIA Ampere GPU Architecture examples and documentation:
|
||||
* [Tensor Float 32](/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) and
|
||||
* [Sparse Tensor Cores](/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu)
|
||||
* Documentation added on CUTLASS [efficient row-major epilogue](/media/docs/gemm_api.md#efficient-epilogue)
|
||||
* [Tensor Float 32](./examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) and
|
||||
* [Sparse Tensor Cores](./examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu)
|
||||
* Documentation added on CUTLASS [efficient row-major epilogue](./media/docs/gemm_api.md#efficient-epilogue)
|
||||
|
||||
## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08)
|
||||
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
|
||||
* Fast Tensor Core operations:
|
||||
* Fast Tensor Core operations:
|
||||
* Maximum performance via [`mma.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends)
|
||||
* Tensor Float 32, BFloat16, and double-precision data types
|
||||
* Mixed integer data types (int8, int4, bin1)
|
||||
* Asynchronous copy for deep software pipelines via [`cp.async`](https://docs.nvidia.com/cuda/parallel-thread-execution)
|
||||
* Asynchronous copy for deep software pipelines via [`cp.async`](https://docs.nvidia.com/cuda/parallel-thread-execution)
|
||||
* Described in [GTC 2020 Webinar (SR 21745)](https://developer.nvidia.com/gtc/2020/video/s21745) (free registration required)
|
||||
* Features:
|
||||
* SDK examples showing GEMM fused with bias+relu and fused GEMM+GEMM
|
||||
@ -217,11 +428,11 @@
|
||||
* Disabled F16C by default for compatibility - enable on cmake command line with `-DCUTLASS_ENABLE_F16C=ON`
|
||||
|
||||
## [2.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.1.0) (2020-04-06)
|
||||
* BLAS-style host-side API added to [CUTLASS Library](/media/docs/quickstart.md#cutlass-library)
|
||||
* BLAS-style host-side API added to [CUTLASS Library](./media/docs/quickstart.md#cutlass-library)
|
||||
* API to launch compiled kernel instances for GEMM and planar complex GEMM
|
||||
* Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores
|
||||
* Computes complex matrix products on matrices stored as disjoint real and imaginary parts
|
||||
* [SDK Examples of Planar Complex GEMMs](/examples/10_planar_complex/planar_complex.cu)
|
||||
* [SDK Examples of Planar Complex GEMMs](./examples/10_planar_complex/planar_complex.cu)
|
||||
* Minor enhancements and bug fixes
|
||||
|
||||
## [2.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.0.0) (2019-11-19)
|
||||
@ -231,10 +442,10 @@
|
||||
* Encapsulated functionality embodying modern C++11 programming techniques
|
||||
* Optimized containers and data types for efficient, generic, portable device code
|
||||
* Updates to:
|
||||
* [Quick start guide](/media/docs/quickstart.md)
|
||||
* [Documentation](/README.md#documentation)
|
||||
* [Utilities](/media/docs/utilities.md)
|
||||
* [CUTLASS Profiler](/media/docs/profiler.md)
|
||||
* [Quick start guide](./media/docs/quickstart.md)
|
||||
* [Documentation](./README.md#documentation)
|
||||
* [Utilities](./media/docs/utilities.md)
|
||||
* [CUTLASS Profiler](./media/docs/profiler.md)
|
||||
* Native Turing Tensor Cores
|
||||
* Efficient GEMM kernels targeting Turing Tensor Cores
|
||||
* Mixed-precision floating point, 8-bit integer, 4-bit integer, and binarized operands
|
||||
@ -298,7 +509,7 @@
|
||||
|
||||
## Copyright
|
||||
|
||||
Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
|
||||
696
CMakeLists.txt
696
CMakeLists.txt
File diff suppressed because it is too large
Load Diff
133
CONTRIBUTORS.md
133
CONTRIBUTORS.md
@ -1,49 +1,105 @@
|
||||

|
||||

|
||||
|
||||
[README](/README.md#documentation) > **Contributors**
|
||||
[README](./README.md#documentation) > **Contributors**
|
||||
|
||||
# CUTLASS Developers and Contributors
|
||||
# CUTLASS Developers **
|
||||
|
||||
This is the official list of CUTLASS developers and contributors.
|
||||
|
||||
## DEVELOPERS
|
||||
Vijay Thakkar<br />
|
||||
Pradeep Ramani<br />
|
||||
Cris Cecka<br />
|
||||
Aniket Shivam<br />
|
||||
Jack Kosaian<br />
|
||||
Mark Hoemmen<br />
|
||||
Honghao Lu<br />
|
||||
Ethan Yan<br />
|
||||
Haicheng Wu<br />
|
||||
Andrew Kerr<br />
|
||||
Dustyn Blasig<br />
|
||||
Fengqi Qiao<br />
|
||||
Duane Merrill<br />
|
||||
Yujia Zhai<br />
|
||||
Shang Zhang<br />
|
||||
Piotr Majcher<br />
|
||||
Paul Springer<br />
|
||||
Markus Hohnerbach<br />
|
||||
Jin Wang<br />
|
||||
Dustyn Blasig<br />
|
||||
Albert Xu<br />
|
||||
Junkai Wu<br />
|
||||
Xiuxia Zhang<br />
|
||||
Haicheng Wu<br />
|
||||
Jack Yang<br />
|
||||
Pradeep Ramani<br />
|
||||
Aditya Atluri<br />
|
||||
Han Li<br />
|
||||
Nick Zhao<br />
|
||||
Ivan Yin<br />
|
||||
Yu-Jung Chen<br />
|
||||
Markus Hoehnerbach<br />
|
||||
Honghao Lu<br />
|
||||
Mihir Awatramani<br />
|
||||
Hao Sheng<br />
|
||||
Zekun Fan<br />
|
||||
Aniket Shivam<br />
|
||||
Siyu Liu<br />
|
||||
Richard Cai<br />
|
||||
Vikas Gupta<br />
|
||||
Ethan Yan<br />
|
||||
Vijay Thakkar<br />
|
||||
Cris Cecka<br />
|
||||
Lawrence Ryan<br />
|
||||
Qun Song<br />
|
||||
Daniel Ricketts<br />
|
||||
dePaul Miller<br />
|
||||
Yuhan Li<br />
|
||||
Saman Ashkiani<br />
|
||||
Jack Chen<br />
|
||||
Shang Zhang<br />
|
||||
Petrick Liu<br />
|
||||
Questa Wang<br />
|
||||
Pramod Shenoy<br />
|
||||
Jack Kosaian<br />
|
||||
Yujia Zhai<br />
|
||||
Zhaodong Chen<br />
|
||||
Manas Sahni<br />
|
||||
Shunfan Shao<br />
|
||||
Fengqi Qiao<br />
|
||||
Serif Yesil<br />
|
||||
Aragorn Guan<br />
|
||||
Heidi He<br />
|
||||
Xiao Song<br />
|
||||
Sergey Klevtsov<br />
|
||||
Jiang Shao<br />
|
||||
Ruqing Xu<br />
|
||||
Mengyu Guo<br />
|
||||
Tao Xie<br />
|
||||
Linfeng Zheng<br />
|
||||
Harrison Barclay<br />
|
||||
Wenfei Tang<br />
|
||||
Diksha Gohlyan<br />
|
||||
Alexander Zhurkevich<br />
|
||||
Siyuan Fu<br />
|
||||
Hua Huang<br />
|
||||
Xiufan Liang<br />
|
||||
Ian Tramble<br />
|
||||
Ali Hassani<br />
|
||||
Shreya Gaur<br />
|
||||
|
||||
** _The list is sorted in order of the author's first contribution to the CUTLASS project._
|
||||
|
||||
|
||||
# CUTE Developers
|
||||
|
||||
## CuTe
|
||||
Cris Cecka<br />
|
||||
Vijay Thakkar<br />
|
||||
|
||||
## CUTLASS Product Manager
|
||||
|
||||
# CUTLASS Product Manager
|
||||
|
||||
Matthew Nicely<br />
|
||||
|
||||
## Former CUTLASS Developers
|
||||
Manish Gupta<br />
|
||||
Naila Farooqui<br />
|
||||
David Tanner<br />
|
||||
Manikandan Ananth<br />
|
||||
Zhaodong Chen<br />
|
||||
Chinmay Talegaonkar<br />
|
||||
|
||||
## CONTRIBUTORS
|
||||
# Former CUTLASS Developers
|
||||
|
||||
Manish Gupta<br />
|
||||
Duane Merrill<br />
|
||||
Piotr Majcher<br />
|
||||
Naila Farooqui<br />
|
||||
Mark Hoemmen<br />
|
||||
Rawn Henry<br />
|
||||
Jin Wang<br />
|
||||
Timmy Liu<br />
|
||||
Manikandan Ananth<br />
|
||||
David Tanner<br />
|
||||
|
||||
|
||||
# Acknowledgements
|
||||
|
||||
Tri Dao<br />
|
||||
Jay Shah<br />
|
||||
Timothy Costa<br />
|
||||
Julien Demouth<br />
|
||||
Brian Fahs<br />
|
||||
@ -53,24 +109,15 @@ Mostafa Hagog<br />
|
||||
Fei Hu<br />
|
||||
Alan Kaatz<br />
|
||||
Tina Li<br />
|
||||
Timmy Liu<br />
|
||||
Wei Liu<br />
|
||||
Duane Merrill<br />
|
||||
Tim Martin<br />
|
||||
Kevin Siu<br />
|
||||
Markus Tavenrath<br />
|
||||
John Tran<br />
|
||||
Vicki Wang<br />
|
||||
Junkai Wu<br />
|
||||
Fung Xie<br />
|
||||
Albert Xu<br />
|
||||
Yang Xu<br />
|
||||
Jack Yang<br />
|
||||
Scott Yokim<br />
|
||||
Xiuxia Zhang<br />
|
||||
Nick Zhao<br />
|
||||
|
||||
## ACKNOWLEDGEMENTS
|
||||
|
||||
Girish Bharambe<br />
|
||||
Luke Durant<br />
|
||||
Carter Edwards<br />
|
||||
|
||||
133
CUDA.cmake
133
CUDA.cmake
@ -1,4 +1,4 @@
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
@ -26,49 +26,46 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
if(CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
set(CUTLASS_NATIVE_CUDA_INIT ON)
|
||||
elseif(CMAKE_VERSION VERSION_LESS 3.12.4)
|
||||
set(CUTLASS_NATIVE_CUDA_INIT OFF)
|
||||
else()
|
||||
set(CUTLASS_NATIVE_CUDA_INIT ON)
|
||||
if (CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
message(WARNING "CUDA_COMPILER flag is deprecated, set CMAKE_CUDA_COMPILER to desired compiler executable.")
|
||||
set(__CLANG_DEVICE_COMPILATION_REQUESTED ON)
|
||||
elseif(CUDA_COMPILER)
|
||||
message(WARNING "Deprecated flag CUDA_COMPILER used with unknown argument ${CUDA_COMPILER}, ignoring.")
|
||||
endif()
|
||||
|
||||
set(CUTLASS_NATIVE_CUDA ${CUTLASS_NATIVE_CUDA_INIT} CACHE BOOL "Utilize the CMake native CUDA flow")
|
||||
|
||||
if(NOT DEFINED ENV{CUDACXX} AND NOT DEFINED ENV{CUDA_BIN_PATH} AND DEFINED ENV{CUDA_PATH})
|
||||
# For backward compatibility, allow use of CUDA_PATH.
|
||||
set(ENV{CUDACXX} $ENV{CUDA_PATH}/bin/nvcc)
|
||||
if (__CLANG_DEVICE_COMPILATION_REQUESTED AND NOT DEFINED CMAKE_CUDA_COMPILER)
|
||||
set(CMAKE_CUDA_COMPILER clang++) # We will let the system find Clang or error out
|
||||
endif()
|
||||
|
||||
if(CUTLASS_NATIVE_CUDA)
|
||||
enable_language(CUDA)
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
|
||||
enable_language(CUDA)
|
||||
|
||||
if(NOT CUDA_VERSION)
|
||||
set(CUDA_VERSION ${CMAKE_CUDA_COMPILER_VERSION})
|
||||
endif()
|
||||
if(NOT CUDA_TOOLKIT_ROOT_DIR)
|
||||
get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CMAKE_CUDA_COMPILER}/../.." ABSOLUTE)
|
||||
endif()
|
||||
if(NOT CUDA_VERSION)
|
||||
# For backward compatibility with older CMake code.
|
||||
set(CUDA_VERSION ${CUDAToolkit_VERSION})
|
||||
set(CUDA_VERSION_MAJOR ${CUDAToolkit_VERSION_MAJOR})
|
||||
set(CUDA_VERSION_MINOR ${CUDAToolkit_VERSION_MINOR})
|
||||
endif()
|
||||
if(NOT CUDA_TOOLKIT_ROOT_DIR)
|
||||
# In some scenarios, such as clang device compilation, the toolkit root may not be set, so we
|
||||
# force it here to the nvcc we found via the CUDAToolkit package.
|
||||
get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CUDAToolkit_NVCC_EXECUTABLE}/../.." ABSOLUTE)
|
||||
endif()
|
||||
|
||||
if (CMAKE_CUDA_COMPILER_ID MATCHES "(nvcc|[Nn][Vv][Ii][Dd][Ii][Aa])")
|
||||
set(CUTLASS_NVCC_DEVICE_COMPILE ON CACHE BOOL "Using nvcc tools for device compilation")
|
||||
elseif (CMAKE_CUDA_COMPILER_ID MATCHES "[Cc]lang")
|
||||
set(CUTLASS_CLANG_DEVICE_COMPILE ON CACHE BOOL "Using Clang tools for device compilation")
|
||||
else()
|
||||
message(FATAL_ERROR "Uknown device-side compiler ${CMAKE_CUDA_COMPILER_ID} found. Set CMAKE_CUDA_COMPILER to either nvcc or clang++.")
|
||||
endif()
|
||||
|
||||
find_package(CUDA REQUIRED)
|
||||
# We workaround missing variables with the native flow by also finding the CUDA toolkit the old way.
|
||||
|
||||
if(NOT CMAKE_CUDA_COMPILER_VERSION)
|
||||
set(CMAKE_CUDA_COMPILER_VERSION ${CUDA_VERSION})
|
||||
endif()
|
||||
|
||||
if (CUTLASS_CLANG_DEVICE_COMPILE AND CMAKE_VERSION VERSION_LESS_EQUAL "3.30")
|
||||
message(FATAL_ERROR "Clang device compilation for CUTLASS requires CMake 3.30 or higher.")
|
||||
endif()
|
||||
|
||||
if (CUDA_VERSION VERSION_LESS 9.2)
|
||||
message(FATAL_ERROR "CUDA 9.2+ Required, Found ${CUDA_VERSION}.")
|
||||
endif()
|
||||
if(NOT CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
set(CMAKE_CUDA_COMPILER ${CUDA_TOOLKIT_ROOT_DIR}/bin/nvcc)
|
||||
message(STATUS "CUDA Compiler: ${CMAKE_CUDA_COMPILER}")
|
||||
message(FATAL_ERROR "CUDA 9.2+ required, found ${CUDA_VERSION}.")
|
||||
endif()
|
||||
|
||||
find_library(
|
||||
@ -76,6 +73,7 @@ find_library(
|
||||
PATHS
|
||||
${CUDA_TOOLKIT_ROOT_DIR}
|
||||
PATH_SUFFIXES
|
||||
lib/x86_64-linux-gnu
|
||||
lib/x64
|
||||
lib64
|
||||
lib
|
||||
@ -120,6 +118,7 @@ find_library(
|
||||
PATHS
|
||||
${CUDA_TOOLKIT_ROOT_DIR}
|
||||
PATH_SUFFIXES
|
||||
lib/x86_64-linux-gnu
|
||||
lib/x64
|
||||
lib64
|
||||
lib
|
||||
@ -209,16 +208,6 @@ include_directories(SYSTEM ${CUDA_INCLUDE_DIRS})
|
||||
# Some platforms (e.g. Visual Studio) don't add the CUDA include directories to the system include
|
||||
# paths by default, so we add it explicitly here.
|
||||
|
||||
function(cutlass_correct_source_file_language_property)
|
||||
if(CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
foreach(File ${ARGN})
|
||||
if(File MATCHES ".*\.cu$")
|
||||
set_source_files_properties(${File} PROPERTIES LANGUAGE CXX)
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
if (MSVC OR CUTLASS_LIBRARY_KERNELS MATCHES "all")
|
||||
set(CUTLASS_UNITY_BUILD_ENABLED_INIT ON)
|
||||
else()
|
||||
@ -226,7 +215,14 @@ else()
|
||||
endif()
|
||||
|
||||
set(CUTLASS_UNITY_BUILD_ENABLED ${CUTLASS_UNITY_BUILD_ENABLED_INIT} CACHE BOOL "Enable combined source compilation")
|
||||
set(CUTLASS_UNITY_BUILD_BATCH_SIZE 16 CACHE STRING "Batch size for unified source files")
|
||||
|
||||
if (MSVC)
|
||||
set(CUTLASS_UNITY_BUILD_BATCH_SIZE_INIT 8)
|
||||
else()
|
||||
set(CUTLASS_UNITY_BUILD_BATCH_SIZE_INIT 16)
|
||||
endif()
|
||||
|
||||
set(CUTLASS_UNITY_BUILD_BATCH_SIZE ${CUTLASS_UNITY_BUILD_BATCH_SIZE_INIT} CACHE STRING "Batch size for unified source files")
|
||||
|
||||
function(cutlass_unify_source_files TARGET_ARGS_VAR)
|
||||
|
||||
@ -239,11 +235,15 @@ function(cutlass_unify_source_files TARGET_ARGS_VAR)
|
||||
message(FATAL_ERROR "TARGET_ARGS_VAR parameter is required")
|
||||
endif()
|
||||
|
||||
if (NOT DEFINED __BATCH_SOURCES)
|
||||
set(__BATCH_SOURCES ON)
|
||||
endif()
|
||||
|
||||
if (__BATCH_SOURCES AND NOT DEFINED __BATCH_SIZE)
|
||||
set(__BATCH_SIZE ${CUTLASS_UNITY_BUILD_BATCH_SIZE})
|
||||
endif()
|
||||
|
||||
if (CUTLASS_UNITY_BUILD_ENABLED AND DEFINED __BATCH_SIZE AND __BATCH_SIZE GREATER 1)
|
||||
if (CUTLASS_UNITY_BUILD_ENABLED AND __BATCH_SOURCES AND __BATCH_SIZE GREATER 1)
|
||||
|
||||
set(CUDA_FILE_ARGS)
|
||||
set(TARGET_SOURCE_ARGS)
|
||||
@ -293,18 +293,13 @@ function(cutlass_add_library NAME)
|
||||
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
|
||||
|
||||
if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang")
|
||||
cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS})
|
||||
add_library(${NAME} ${TARGET_SOURCE_ARGS})
|
||||
else()
|
||||
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
|
||||
cuda_add_library(${NAME} ${TARGET_SOURCE_ARGS})
|
||||
endif()
|
||||
|
||||
add_library(${NAME} ${TARGET_SOURCE_ARGS} "")
|
||||
|
||||
cutlass_apply_standard_compile_options(${NAME})
|
||||
|
||||
if (NOT __SKIP_GENCODE_FLAGS)
|
||||
cutlass_apply_cuda_gencode_flags(${NAME})
|
||||
cutlass_apply_cuda_gencode_flags(${NAME})
|
||||
endif()
|
||||
|
||||
target_compile_features(
|
||||
@ -313,6 +308,14 @@ function(cutlass_add_library NAME)
|
||||
cxx_std_11
|
||||
)
|
||||
|
||||
get_target_property(TARGET_TYPE ${NAME} TYPE)
|
||||
|
||||
if (TARGET_TYPE MATCHES "SHARED")
|
||||
set_target_properties(${NAME} PROPERTIES CUDA_RUNTIME_LIBRARY Shared)
|
||||
elseif(TARGET_TYPE MATCHES "STATIC")
|
||||
set_target_properties(${NAME} PROPERTIES CUDA_RUNTIME_LIBRARY Static)
|
||||
endif()
|
||||
|
||||
if(__EXPORT_NAME)
|
||||
add_library(nvidia::cutlass::${__EXPORT_NAME} ALIAS ${NAME})
|
||||
set_target_properties(${NAME} PROPERTIES EXPORT_NAME ${__EXPORT_NAME})
|
||||
@ -323,19 +326,22 @@ endfunction()
|
||||
function(cutlass_add_executable NAME)
|
||||
|
||||
set(options)
|
||||
set(oneValueArgs)
|
||||
set(oneValueArgs CUDA_RUNTIME_LIBRARY)
|
||||
set(multiValueArgs)
|
||||
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
if (NOT DEFINED __CUDA_RUNTIME_LIBRARY)
|
||||
set(__CUDA_RUNTIME_LIBRARY Shared)
|
||||
endif()
|
||||
|
||||
set(__CUDA_RUNTIME_LIBRARY_ALLOWED None Shared Static)
|
||||
if (NOT __CUDA_RUNTIME_LIBRARY IN_LIST __CUDA_RUNTIME_LIBRARY_ALLOWED)
|
||||
message(FATAL_ERROR "CUDA_RUNTIME_LIBRARY value '${__CUDA_RUNTIME_LIBRARY}' is not in allowed list of '${__CUDA_RUNTIME_LIBRARY_ALLOWED}'")
|
||||
endif()
|
||||
|
||||
cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
|
||||
|
||||
if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang")
|
||||
cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS})
|
||||
add_executable(${NAME} ${TARGET_SOURCE_ARGS})
|
||||
else()
|
||||
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
|
||||
cuda_add_executable(${NAME} ${TARGET_SOURCE_ARGS})
|
||||
endif()
|
||||
add_executable(${NAME} ${TARGET_SOURCE_ARGS})
|
||||
|
||||
cutlass_apply_standard_compile_options(${NAME})
|
||||
cutlass_apply_cuda_gencode_flags(${NAME})
|
||||
@ -346,6 +352,8 @@ function(cutlass_add_executable NAME)
|
||||
cxx_std_11
|
||||
)
|
||||
|
||||
set_target_properties(${NAME} PROPERTIES CUDA_RUNTIME_LIBRARY ${__CUDA_RUNTIME_LIBRARY})
|
||||
|
||||
endfunction()
|
||||
|
||||
function(cutlass_target_sources NAME)
|
||||
@ -356,7 +364,6 @@ function(cutlass_target_sources NAME)
|
||||
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
|
||||
cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS})
|
||||
target_sources(${NAME} ${TARGET_SOURCE_ARGS})
|
||||
|
||||
endfunction()
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,13 +1,49 @@
|
||||
# Publications Using Cutlass
|
||||
|
||||
## 2025
|
||||
|
||||
- ["ParetoQ: Scaling Laws in Extremely Low-bit LLM Quantization"](https://arxiv.org/abs/2502.02631). Zechun Liu, Changsheng Zhao, Hanxian Huang, Sijia Chen, Jing Zhang, Jiawei Zhao, Scott Roy, Lisa Jin, Yunyang Xiong, Yangyang Shi, Lin Xiao, Yuandong Tian, Bilge Soran, Raghuraman Krishnamoorthi, Tijmen Blankevoort, Vikas Chandra. _arXiv_, February 2025.
|
||||
|
||||
## 2024
|
||||
|
||||
- ["ShadowKV: KV Cache in Shadows for High-Throughput Long-Context LLM Inference"](https://arxiv.org/abs/2410.21465). Hanshi Sun, Li-Wen Chang, Wenlei Bao, Size Zheng, Ningxin Zheng, Xin Liu, Harry Dong, Yuejie Chi, Beidi Chen. _arXiv_, October 2024.
|
||||
|
||||
- ["FLUX: Fast Software-based Communication Overlap On GPUs Through Kernel Fusion"](https://arxiv.org/abs/2406.06858). Li-Wen Chang, Wenlei Bao, Qi Hou, Chengquan Jiang, Ningxin Zheng, Yinmin Zhong, Xuanrun Zhang, Zuquan Song, Chengji Yao, Ziheng Jiang, Haibin Lin, Xin Jin, Xin Liu. _arXiv_, June 2024.
|
||||
|
||||
- ["EVT: Accelerating Deep Learning Training with Epilogue Visitor Tree"](https://dl.acm.org/doi/10.1145/3620666.3651369). Zhaodong Chen, Andrew Kerr, Richard Cai, Jack Kosaian, Haicheng Wu, Yufei Ding, and Yuan Xie. _Proceedings of the 29th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, April 2024.
|
||||
|
||||
- ["Faster Neighborhood Attention: Reducing the O(n^2) Cost of Self Attention at the Threadblock Level"](https://arxiv.org/abs/2403.04690). Ali Hassani, Wen-Mei Hwu, Humphrey Shi. _arXiv_, March 2024.
|
||||
|
||||
## 2023
|
||||
|
||||
- ["A Case Study in CUDA Kernel Fusion: Implementing FlashAttention-2 on NVIDIA Hopper Architecture using the CUTLASS Library"](https://arxiv.org/abs/2312.11918). Ganesh Bikshandi, Jay Shah. _arXiv_, December 2023.
|
||||
|
||||
- ["Benchmarking GPU Tensor Cores on General Matrix Multiplication Kernels through CUTLASS"](https://www.mdpi.com/2076-3417/13/24/13022). Xuanteng Huang, Xianwei Zhang, Panfei Yang, Nong Xiao. _Journal of Applied Sciences_, December 2023.
|
||||
|
||||
- ["A Speed Odyssey for Deployable Quantization of LLMs"](https://arxiv.org/abs/2311.09550). Qingyuan Li, Ran Meng, Yiduo Li, Bo Zhang, Liang Li, Yifan Lu, Xiangxiang Chu, Yerui Sun, Yuchen Xie. _arXiv_, November 2023.
|
||||
|
||||
- ["FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning"](https://arxiv.org/abs/2307.08691). Tri Dao. _Technical Report_, July 2023.
|
||||
|
||||
- ["MegaBlocks: Efficient Sparse Training with Mixture-of-Experts"](https://arxiv.org/abs/2211.15841). Trevor Gale, Deepak Narayanan, Cliff Young, Matei Zaharia. _Proceedings of the Sixth Machine Learning and Systems_, May 2023.
|
||||
|
||||
- ["ByteTransformer: A High-Performance Transformer Boosted for Variable-Length Inputs"](https://arxiv.org/abs/2210.03052). Yujia Zhai, Chengquan Jiang, Leyuan Wang, Xiaoying Jia, Shang Zhang, Zizhong Chen, Xin Liu, Yibo Zhu. _Proceedings of the 37th IEEE International Parallel & Distributed Processing Symposium (Best Paper)_, May 2023.
|
||||
|
||||
- ["A Framework for Fine-Grained Synchronization of Dependent GPU Kernels"](https://arxiv.org/abs/2305.13450). Abhinav Jangda, Saeed Maleki, Maryam Mehri Dehnavi, Madan Musuvathi, Olli Saarikivi. _Computing Research Repository_, May 2023.
|
||||
|
||||
- ["Graphene: An IR for Optimized Tensor Computations on GPUs"](https://dl.acm.org/doi/pdf/10.1145/3582016.3582018). Hagedorn, Bastian, Bin Fan, Hanfeng Chen, Cris Cecka, Michael Garland, Vinod Grover. _Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, March 2023.
|
||||
|
||||
- ["Mixed Precision Post Training Quantization of Neural Networks with Sensitivity Guided Search"](https://arxiv.org/abs/2302.01382). Clemens JS Schaefer, Elfie Guo, Caitlin Stanton, Xiaofan Zhang, Tom Jablin, Navid Lambert-Shirzad, Jian Li, Chiachen Chou, Siddharth Joshi, Yu Emma Wang. _arXiv_, Feburary 2023.
|
||||
|
||||
- ["Dynamic N:M Fine-Grained Structured Sparse Attention Mechanism"](https://dl.acm.org/doi/abs/10.1145/3572848.3577500). Zhaodong Chen, Zheng Qu, Yuying Quan, Liu Liu, Yufei Ding, Yuan Xie. _Proceedings of the 28th ACM SIGPLAN Annual Symposium on Principles and Practice of Parallel Programming_, Feburary 2023.
|
||||
|
||||
- ["Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU"](https://arxiv.org/abs/2301.03598). Muhammad Osama, Duane Merrill, Cris Cecka, Michael Garland, John D. Owens. _arXiv_, January 2023.
|
||||
|
||||
## 2022
|
||||
|
||||
- ["GPU Load Balancing"](https://arxiv.org/abs/2212.08964). Muhammad Osama. _Doctoral dissertation, University of California, Davis_, December 2022.
|
||||
|
||||
- ["Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production"](https://arxiv.org/abs/2211.10017). Young Jin Kim, Rawn Henry, Raffy Fahim, Hany Hassan Awadalla. _Proceedings of the Third Workshop on Simple and Efficient Natural Language Processing_, December 2022.
|
||||
|
||||
- ["Bolt: Bridging the Gap between Auto-tuners and Hardware-native Performance"](https://arxiv.org/abs/2110.15238). Jiarong Xing, Leyuan Wang, Shang Zhang, Jack Chen, Ang Chen, Yibo Zhu. _Proceedings of the 5th MLSys Conference_, August 2022.
|
||||
|
||||
- ["Recovering single precision accuracy from Tensor Cores while surpassing the FP32 theoretical peak performance"](https://arxiv.org/abs/2203.03341). Hiroyuki Ootomo, Rio Yokota. _International Journal of High Performance Computing_, March 2022.
|
||||
@ -18,7 +54,7 @@
|
||||
|
||||
- ["Arithmetic-intensity-guided fault tolerance for neural network inference on GPUs"](https://dl.acm.org/doi/abs/10.1145/3458817.3476184). Jack Kosaian, K. V. Rashmi. _Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis_, November 2021.
|
||||
|
||||
- ["Real-time Neural Radiance Caching for Path Tracing"](https://d1qx31qr3h6wln.cloudfront.net/publications/paper_4.pdf). Thomas Muller, Fabrice Rousselle, Jan Novak, Alex Keller. _ACM Trans. Graph._, August 2021.
|
||||
- ["Real-time Neural Radiance Caching for Path Tracing"](https://dl.acm.org/doi/abs/10.1145/3450626.3459812). Thomas Muller, Fabrice Rousselle, Jan Novak, Alex Keller. _ACM Trans. Graph._, August 2021.
|
||||
|
||||
## 2020
|
||||
|
||||
|
||||
299
README.md
299
README.md
@ -1,8 +1,8 @@
|
||||

|
||||

|
||||
|
||||
# CUTLASS 3.0
|
||||
# CUTLASS 3.8.0
|
||||
|
||||
_CUTLASS 3.0 - January 2023_
|
||||
_CUTLASS 3.8.0 - January 2025_
|
||||
|
||||
CUTLASS is a collection of CUDA C++ template abstractions for implementing
|
||||
high-performance matrix-matrix multiplication (GEMM) and related computations at all levels
|
||||
@ -16,93 +16,149 @@ as building blocks within custom kernels and applications.
|
||||
|
||||
To support a wide variety of applications, CUTLASS provides extensive support for
|
||||
mixed-precision computations, providing specialized data-movement and
|
||||
multiply-accumulate abstractions for half-precision floating
|
||||
point (FP16), BFloat16 (BF16), Tensor Float 32 (TF32),
|
||||
single-precision floating point (FP32),
|
||||
[FP32 emulation via tensor core instruction](/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm),
|
||||
double-precision floating
|
||||
point (FP64) types, integer data types (4b and 8b), and binary data types (1b).
|
||||
CUTLASS demonstrates warp-synchronous matrix multiply operations
|
||||
multiply-accumulate abstractions for FP64, FP32, TF32, FP16, BF16,
|
||||
[FP32 emulation via tensor core instruction](./examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm),
|
||||
8b floating point types (e5m2 and e4m3),
|
||||
block scaled data types (NVIDIA NVFP4 and OCP standard MXFP4, MXFP6, MXFP8),
|
||||
narrow integer types (4 and 8b signed and unsigned integers),
|
||||
and binary 1b data types (where architectures allow for the
|
||||
native support of such data types).
|
||||
CUTLASS demonstrates optimal matrix multiply operations
|
||||
targeting the programmable, high-throughput _Tensor Cores_ implemented by
|
||||
NVIDIA's Volta, Turing, Ampere, and Hopper architectures.
|
||||
NVIDIA's Volta, Turing, Ampere, Ada, Hopper, and Blackwell architectures.
|
||||
|
||||
See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
|
||||
In addition to GEMMs, CUTLASS implements high-performance convolution via
|
||||
the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution
|
||||
operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline.
|
||||
This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components.
|
||||
|
||||
See the [functionality listing](/media/docs/functionality.md) for the list of operations
|
||||
supported at each level of the execution model hierarchy.
|
||||
See the [Quick Start Guide](./media/docs/quickstart.md) to get started quickly.
|
||||
|
||||
CUTLASS 3.0 introduces a new core library, CuTe, to describe and manipulate tensors of threads and data.
|
||||
CuTe is a collection of C++ CUDA template abstractions for defining and operating on hierarchically multidimensional layouts of threads and data. CuTe provides `Layout` and `Tensor` objects that compactly package the type, shape, memory space, and layout of data, while performing the complicated indexing for the user. This lets programmers focus on the logical descriptions of their algorithms while CuTe does the mechanical bookkeeping for them. With these tools, we can quickly design, implement, and modify all dense linear algebra operations.
|
||||
See the [functionality docs](./media/docs/functionality.md) for a more comprehensive
|
||||
list of kernel level features, data types, instructions, and minimum supported by CUTLASS on each GPU
|
||||
architecture.
|
||||
|
||||
The core abstractions of CuTe are hierarchically multidimensional layouts which can be composed with data arrays to represent tensors. The representation of layouts is powerful enough to represent nearly everything we need to implement efficient dense linear algebra. Layouts can also be combined and manipulated via functional composition, on which we build a large set of common operations such as tiling and partitioning.
|
||||
# What's New in CUTLASS 3.8
|
||||
|
||||
CUTLASS 3.0 adopts CuTe throughout the GEMM hierarchy in its templates. This greatly simplifies the design
|
||||
and improves code composability and readability. More documentation specific to CuTe can be found in its [dedicated documentation directory](/media/docs/cute/00_quickstart.md).
|
||||
CUTLASS 3.8 is the first release that supports the NVIDIA Blackwell SM100 architecture.
|
||||
For a background on Blackwell's new features, please consult the PTX documentation for CUDA 12.8.
|
||||
|
||||
In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components.
|
||||
* Support for new CuTe building blocks specifically for Blackwell SM100 architecture:
|
||||
- [5th generation Blackwell Tensor Core instructions (TCGen05)](./include/cute/atom/mma_traits_sm100.hpp) via CuTe MMA atoms.
|
||||
- Extensions to [Tensor Memory Accelerator](./include/cute/atom/copy_traits_sm100_tma.hpp) via CuTe Copy atoms.
|
||||
- Exposure of Blackwell's new tensor memory (note: distinct from TMA) as [`tmem`](./include/cute/pointer.hpp) across CuTe as a first class data locale.
|
||||
- Exposure of [`tmem->rmem`, `rmem->tmem` and `smem->tmem data movement instructions`](./include/cute/atom/copy_traits_sm100.hpp) as copy atoms in CuTe.
|
||||
- [`make_tmem_copy()`](./include/cute/atom/copy_traits_sm100.hpp) utility method to ease creation of tiled copies for tmem copy atoms.
|
||||
- Support for [new variants of LDSM on Blackwell](./include/cute/atom/copy_traits_sm100.hpp) via CuTe Copy atoms.
|
||||
* Support for new CUTLASS building blocks specifically for Blackwell SM100 architecture:
|
||||
- Various narrow precision [FP4, FP6, and FP8](./include/cutlass/exmy_base.h) formats as well as their [block-scaled variants NVFP4, MXFP4, MXFP6, and MXFP8](./include/cutlass/float_subbyte.h)
|
||||
- [Pipelines that implement Blackwell specific synchronization](./include/cutlass/pipeline/sm100_pipeline.hpp).
|
||||
- [Cluster launch control API supporting preferred and fallback cluster shapes](./include/cutlass/cluster_launch.hpp).
|
||||
- Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types.
|
||||
- Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./media/docs/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](./include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](./include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp).
|
||||
- Extensions to testbeds and reference check code for unit tests and CUTLASS profiler.
|
||||
* Full support for Blackwell SM100 kernels in CUTLASS 3.x API:
|
||||
- [Blackwell specific kernel layers](./include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that
|
||||
+ Implement a new warp-specialization recipe tuned specifically for Blackwell SM100 architecture.
|
||||
+ Leverage all the new features such as CLC based tile scheduling, preferred cluster, and TMEM based double buffering of accumulators.
|
||||
+ Support stream-K load balancing for all kernel types everywhere via composable scheduler support.
|
||||
- Blackwell collective mainloops that target the TCGen05 MMA instructions (both SS and TS) for
|
||||
* [Non-block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp)
|
||||
* [Non-block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp)
|
||||
* [Block scaled data types without support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp)
|
||||
* [Block scaled data types with support for pointer array and grouped GEMM with TMA](./include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp)
|
||||
- Blackwell [collective mainloop for convolution kernels](./include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp) supporting non-block scaled data types for fprop, dgrad, and wgrad.
|
||||
- New [GEMM](./include/cutlass/gemm/dispatch_policy.hpp), [convolution](./include/cutlass/conv/dispatch_policy.hpp), and [epilogue](./include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders.
|
||||
- [Blackwell epilogue that supports loading accumulators from `tmem`](./include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp) and [full set of EVT fusions]().
|
||||
* CUTLASS library and profiler integration for block scaled data types for kernel emission, profiling, and verification.
|
||||
- Support for preferred and fallback cluster shapes via profiler command line arguments parsing to set dynamic cluster shapes.
|
||||
- Support for dynamic datatypes by parsing profiler via profiler command line arguments parsing to set dynamic datatype setting in TCGen05 MMA instruction descriptors.
|
||||
- Support for mixed input GEMM kernels on Hopper in the profiler.
|
||||
* New CUTLASS profiler flag `use-cuda-graphs` to reduce overheads when benchmarking launch-bound kernels.
|
||||
* A new 3.x version of grouped GEMM to the CUTLASS library and generates kernels for Hopper and Blackwell. Now grouped GEMM support is enabled in the CUTLASS profiler (`./cutlass_profiler --operation=GroupedGemm --help` for details).
|
||||
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM100 architecture:
|
||||
- [Basic FP16 and FP8 GEMMs with minimal changes from Hopper examples](./examples/70_blackwell_gemm/), demonstrating ease of migration for off the shelf kernels using the 3.x collective builder API.
|
||||
- GEMM with [opt-in collective builder schedules showcasing available recipes](./examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu) for Blackwell.
|
||||
- Block scaled data type GEMMs targeting Blackwell's native block scaled Tensor Cores:
|
||||
+ [NVFP4 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu)
|
||||
+ [NVFP4 inputs with NVFP4 output](./examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu)
|
||||
+ [Mixed MXFP8 and MXFP6 inputs with BF16 output](./examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu)
|
||||
- GEMM example demonstrating [Blackwell's new preferred cluster support via dynamic cluster shapes](./examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu) for increased occupancy.
|
||||
- [GEMM with CLC based StreamK scheduler for load balancing](./examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu).
|
||||
- Grouped GEMM for [vanilla FP8 data inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu) and [NVFP4 block scaled inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu).
|
||||
- Convolution kernels for [fprop](./examples/76_blackwell_conv/76_blackwell_conv_fprop.cu), [dgrad](./examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu), and [wgrad](./examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu).
|
||||
- [Fused multi-head attention fprop kernel](./examples/77_blackwell_fmha/77_blackwell_fmha.cu) supporting fp16/bf16/fp8 data types across head dims of 32,64, and 128.
|
||||
- A new BF16x9 GEMM [kernel](./examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu) that emulates FP32 GEMM (SGEMM) using BF16 operations.
|
||||
* Set of examples that demonstrate the usage of the 3.x API for targeting Hopper architecture:
|
||||
- A set of new [Hopper grouped GEMM kernels](./examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes.
|
||||
- A new [Hopper FP8 GEMM with groupwise scaling](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu).
|
||||
* Documentation updates:
|
||||
- [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/quickstart.md#instantiating-a-blackwell-gemm-kernel).
|
||||
- Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/blackwell_functionality.md)
|
||||
- A new [functionality documentation](./media/docs/functionality.md) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures.
|
||||
- Updates to [compatibility](./README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](./README.md#Target-Architecture).
|
||||
|
||||
# What's New in CUTLASS 3.0
|
||||
Note: CUTLASS 3.x builds are known to be down on Windows platforms for all CUDA toolkits.
|
||||
CUTLASS team is working on a fix.
|
||||
|
||||
CUTLASS 3.0, as the next major version of the CUTLASS API, brings with it CuTe, a new programming model and backend designed for massively parallel heterogenous agents. Using CuTe, CUTLASS 3.0 provides implementations of GEMM kernels for the NVIDIA Hopper architecture.
|
||||
|
||||
- [CuTe-based layouts and layout algebra](/media/docs/cute/00_quickstart.md)
|
||||
- [A new GEMM template API](/media/docs/gemm_api_3x.md) that eschews the architecture-centric hierarchy of 2.x in favour of a new conceptual framing. Read more in the [3.0 design documentation](/media/docs/cutlass_3x_design.md).
|
||||
- Support for 4th generation Hopper Tensor Core instructions (WGMMA) through CuTe.
|
||||
- Support for Hopper asynchronous Tensor Memory Accelerator (TMA) instructions and associated transaction barriers through CuTe.
|
||||
- New warp-specialized GEMM kernels targeting Hopper TMA + WGMMA for speed-of-light GEMMs.
|
||||
- New warp-specialized persistent GEMM kernels targeting Hopper TMA + WGMMA.
|
||||
- Support for CUDA Threadblock Clusters and programmatic TMA multicast for greater execution and data locality.
|
||||
- A new way to instantiate default GEMM kernels using `CollectiveBuilder`s that supersede the 2.x `DefaultXConfiguration` types in favour a metaprogramming based kernel generator functionality. See [example 49](/examples/49_hopper_gemm_schedules_with_collective_builder/49_hopper_gemm_schedules_with_collective_builder.cu).
|
||||
- Extensions to the CUTLASS library and profiler to support CUTLASS 3.0 Hopper kernels, and a new format
|
||||
for kernel procedural names.
|
||||
- *Announcement*: CUTLASS plans to rename the GitHub branch `master` to `main` with a future release.
|
||||
|
||||
## New architecture, compiler, and CUDA Toolkit requirements
|
||||
|
||||
Minimum requirements:
|
||||
|
||||
- Architecture: Volta
|
||||
- Compiler: Must support at least C++17
|
||||
- CUDA Toolkit version: 11.4
|
||||
|
||||
CUTLASS 3.0 *removes support* for the following:
|
||||
|
||||
- Maxwell and Pascal GPU architectures
|
||||
- Ubuntu 16.04
|
||||
- CUDA 10.2
|
||||
- C++ language versions less than 17.
|
||||
|
||||
**See the [CHANGELOG](CHANGELOG.md) for a detailed listing of releases and updates.**
|
||||
**See the [CHANGELOG](CHANGELOG.md) for details of all past releases and updates.**
|
||||
|
||||
# Performance
|
||||
|
||||
<p align="center"><img src=media/images/cutlass-3.0-gemm-peak-performance.png></p>
|
||||
|
||||
CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
|
||||
they exhibit peak performance comparable to cuBLAS for scalar GEMM
|
||||
computations. The above figure shows CUTLASS performance relative to cuBLAS
|
||||
for large matrix dimensions on an [NVIDIA H100](https://www.nvidia.com/en-us/data-center/h100/) (NVIDIA Hopper architecture),
|
||||
an [NVIDIA L40](https://www.nvidia.com/en-us/data-center/l40/) (NVIDIA Ada architecture),
|
||||
an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) (NVIDIA Ampere architecture),
|
||||
and an [NVIDIA A40](https://www.nvidia.com/en-us/data-center/a40/) (NVIDIA Ampere architecture).
|
||||
CUTLASS 3.0 was compiled with the [CUDA 12.0 Toolkit](https://developer.nvidia.com/cuda-downloads).
|
||||
they exhibit nearly optimal utilization of peak theoretical throughput. The figure below
|
||||
shows CUTLASS 3.8's performance as a % of theoretical peak utilization
|
||||
on various input and output data types when run on NVIDIA Blackwell SM100 architecture GPU.
|
||||
|
||||
<p align="center"><img src=media/images/cutlass-3.8-blackwell-gemm-peak-performance.svg></p>
|
||||
|
||||
The two figures below show the continual CUTLASS performance improvements
|
||||
on an [NVIDIA H100](https://www.nvidia.com/en-us/data-center/h100/) (NVIDIA Hopper architecture) since
|
||||
CUTLASS 3.1.
|
||||
CUTLASS 3.5.1 was compiled with the [CUDA 12.5u1 Toolkit](https://developer.nvidia.com/cuda-downloads).
|
||||
Tensor Core operations are implemented using CUDA's
|
||||
[mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma).
|
||||
[mma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma) and
|
||||
[wgmma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) instructions.
|
||||
|
||||
<p align="center"><img src=media/images/cutlass-2.9-implicit-gemm-performance.png></p>
|
||||
<p align="center"><img src=media/images/cutlass-3.5.1-gemm-peak-performance.png></p>
|
||||
<p align="center"><img src=media/images/cutlass-3.5.1-gemm-peak-performance-fp8.png></p>
|
||||
|
||||
When using CUTLASS building blocks to construct device-wide implicit gemm (Fprop, Dgrad, and Wgrad)
|
||||
kernels, CUTLASS performance is also comparable to cuDNN when running Resnet-50 layers on an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/)
|
||||
as shown in the above figure. Tensor Core operations are still implemented using CUDA's
|
||||
[mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma).
|
||||
# CuTe
|
||||
|
||||
CUTLASS 3.0 introduced a new core library, CuTe, to describe and manipulate tensors of threads and data.
|
||||
CuTe is a collection of C++ CUDA template abstractions for
|
||||
defining and operating on hierarchically multidimensional layouts of threads and data.
|
||||
CuTe provides `Layout` and `Tensor` objects that compactly package the type,
|
||||
shape, memory space, and layout of data, while performing the complicated indexing for the user.
|
||||
This lets programmers focus on the logical descriptions of their algorithms while
|
||||
CuTe does the mechanical bookkeeping for them. With these tools, we can quickly design,
|
||||
implement, and modify all dense linear algebra operations.
|
||||
|
||||
The core abstractions of CuTe are hierarchically multidimensional layouts
|
||||
which can be composed with data arrays to represent tensors.
|
||||
The representation of layouts is powerful enough to represent nearly
|
||||
everything we need to implement efficient dense linear algebra.
|
||||
Layouts can also be combined and manipulated via functional composition, on which we build a large set of common operations such as tiling and partitioning.
|
||||
|
||||
CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates.
|
||||
This greatly simplifies the design and improves code composability and readability.
|
||||
More documentation specific to CuTe can be found in its
|
||||
[dedicated documentation directory](./media/docs/cute/00_quickstart.md).
|
||||
|
||||
# Compatibility
|
||||
|
||||
Minimum requirements:
|
||||
|
||||
- Architecture: Volta (compute capability 7.0)
|
||||
- Compiler: Must support at least C++17
|
||||
- CUDA Toolkit version: 11.4
|
||||
|
||||
CUTLASS requires a C++17 host compiler and
|
||||
performs best when built with the [**CUDA 12.0 Toolkit**](https://developer.nvidia.com/cuda-toolkit).
|
||||
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, and CUDA 11.8.
|
||||
performs best when built with the [**CUDA 12.8 Toolkit**](https://developer.nvidia.com/cuda-downloads).
|
||||
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, and all other CUDA 12.x versions.
|
||||
|
||||
## Operating Systems
|
||||
|
||||
We have tested the following environments.
|
||||
|
||||
|**Operating System** | **Compiler** |
|
||||
@ -111,66 +167,99 @@ We have tested the following environments.
|
||||
| Ubuntu 20.04 | GCC 10.3.0 |
|
||||
| Ubuntu 22.04 | GCC 11.2.0 |
|
||||
|
||||
Note: We plan to add Windows (MSVC) & Clang compiler support soon.
|
||||
Note: GCC 8.5.0 has known regressions regarding fold expressions and overloaded operators. Using GCC 7.5.0 or (preferred) GCC >= 9 is recommended.
|
||||
|
||||
Note: CUTLASS 3.x builds are known to be down on Windows platforms for all CUDA toolkits.
|
||||
CUTLASS team is working on a fix.
|
||||
|
||||
## Hardware
|
||||
|
||||
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on Volta, Turing, Ampere, Ada, and Hopper architecture based NVIDIA GPUs.
|
||||
|
||||
|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit Required by CUTLASS-3**|
|
||||
|---|---|---|
|
||||
|NVIDIA V100 Tensor Core GPU |7.0|11.4|
|
||||
|NVIDIA TitanV |7.0|11.4|
|
||||
|NVIDIA GeForce RTX 2080 TI, 2080, 2070 |7.5|11.4|
|
||||
|NVIDIA GeForce RTX 20x0 series |7.5|11.4|
|
||||
|NVIDIA T4 |7.5|11.4|
|
||||
|NVIDIA A100 Tensor Core GPU |8.0|11.4|
|
||||
|NVIDIA A10 |8.6|11.4|
|
||||
|NVIDIA GeForce RTX 3090 |8.6|11.4|
|
||||
|NVIDIA GeForce RTX 4090 |8.9|11.8|
|
||||
|NVIDIA GeForce RTX 30x0 series |8.6|11.4|
|
||||
|NVIDIA GeForce RTX 40x0 series |8.9|11.8|
|
||||
|NVIDIA L40 |8.9|11.8|
|
||||
|NVIDIA H100 Tensor Core GPU |9.0|11.8|
|
||||
|NVIDIA H200 Tensor Core GPU |9.0|11.8|
|
||||
|NVIDIA B200 Tensor Core GPU |10.0|12.8|
|
||||
|
||||
## Target Architecture
|
||||
|
||||
In general, PTX code generated for one target architecture can be run on future architectures (i.e., it is forward compatible). However, CUDA 12.0 introduces the concept of "architecture-accelerated features" whose PTX does not have forward compatibility guarantees. Several Hopper PTX instructions fall under this category of architecture-accelerated features, and thus require a `sm_90a` target architecture (note the "a" appended). For more details on this and other architecture-accelerated instructions, please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability).
|
||||
In general, PTX code generated for one target architecture can be run on future architectures
|
||||
(i.e., it is forward compatible).
|
||||
However, CUDA 12.0 introduced the concept of "architecture-accelerated features" whose
|
||||
PTX does not have forward compatibility guarantees.
|
||||
Several Hopper and Blackwell PTX instructions fall under this category of
|
||||
architecture-accelerated features, and thus require a `sm_90a` or `sm100a` target architecture
|
||||
(note the "a" appended). For more details on this and other architecture-accelerated instructions,
|
||||
please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability).
|
||||
|
||||
The target architecture information is passed on to CUTLASS via the cmake flag `CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100, users are required to build CUTLASS with `90a` as the target architecture. If a user accidentally builds a kernel which uses SM90a features (e.g. Hopper Tensor Core Instructions), using the SM90 target (note the lack of "a"), with either CTK 12.0 or 11.8, the kernel is expected to fail with a runtime error.
|
||||
The target architecture information is passed on to CUTLASS via the cmake flag
|
||||
`CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100,
|
||||
users are required to build CUTLASS with `90a` as the target architecture.
|
||||
If a user accidentally builds a kernel which uses SM90a features
|
||||
(e.g. Hopper Tensor Core Instructions), using the SM90 target
|
||||
(note the lack of "a"), with either CUDA Toolkit 12 or 11.8,
|
||||
the kernel is expected to fail with a runtime error.
|
||||
|
||||
```
|
||||
cmake .. -DCUTLASS_NVCC_ARCHS="90a"
|
||||
cmake .. -DCUTLASS_NVCC_ARCHS="90a"
|
||||
```
|
||||
Or
|
||||
|
||||
```
|
||||
cmake .. -DCUTLASS_NVCC_ARCHS="100a"
|
||||
```
|
||||
|
||||
Please refer to the [functionality documentation](media/docs/functionality.md) for details on which kernels require which target architectures.
|
||||
Note: The NVIDIA Blackwell SM100 architecture used in the datacenter
|
||||
products has a different compute capability than the one underpinning
|
||||
NVIDIA Blackwell GeForce RTX 50 series GPUs. As a result, kernels
|
||||
compiled for Blackwell SM100 architecture with arch conditional features
|
||||
(using `sm100a`) are not compatible with RTX 50 series GPUs.
|
||||
|
||||
Please refer to the [functionality documentation](./media/docs/functionality.md)
|
||||
for details on which kernels require which target architectures.
|
||||
|
||||
# Documentation
|
||||
|
||||
CUTLASS is described in the following documents and the accompanying
|
||||
[Doxygen documentation](https://nvidia.github.io/cutlass).
|
||||
|
||||
- [Quick Start Guide](/media/docs/quickstart.md) - build and run CUTLASS
|
||||
- [Functionality](/media/docs/functionality.md) - summarizes functionality available in CUTLASS
|
||||
- [Efficient GEMM in CUDA](media/docs/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA
|
||||
- [CUTLASS 3.x Design](media/docs/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
|
||||
- [GEMM API 3.x](media/docs/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts
|
||||
- [GEMM API 2.x](media/docs/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts
|
||||
- [Implicit GEMM Convolution](media/docs/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS
|
||||
- [Code Organization](media/docs/code_organization.md) - describes the organization and contents of the CUTLASS project
|
||||
- [Terminology](media/docs/terminology.md) - describes terms used in the code
|
||||
- [Programming Guidelines](media/docs/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++
|
||||
- [Fundamental types](media/docs/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
|
||||
- [Layouts](media/docs/layout.md) - describes layouts of matrices and tensors in memory
|
||||
- [Tile Iterators](media/docs/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory
|
||||
- [CUTLASS Profiler](media/docs/profiler.md) - command-line driven profiling application
|
||||
- [CUTLASS Utilities](media/docs/utilities.md) - additional templates used to facilate rapid development
|
||||
- [Quick Start Guide](./media/docs/quickstart.md) - basics of building and running CUTLASS
|
||||
- [Functionality](./media/docs/functionality.md) - summarizes functionality available in CUTLASS
|
||||
- [Efficient GEMM in CUDA](./media/docs/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA
|
||||
- [CUTLASS 3.x Design](./media/docs/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components
|
||||
- [GEMM API 3.x](./media/docs/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts
|
||||
- [GEMM API 2.x](./media/docs/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts
|
||||
- [Implicit GEMM Convolution](./media/docs/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS
|
||||
- [Code Organization](./media/docs/code_organization.md) - describes the organization and contents of the CUTLASS project
|
||||
- [Terminology](./media/docs/terminology.md) - describes terms used in the code
|
||||
- [Programming Guidelines](./media/docs/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++
|
||||
- [Fundamental types](./media/docs/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
|
||||
- [Layouts](./media/docs/layout.md) - describes layouts of matrices and tensors in memory
|
||||
- [Tile Iterators](./media/docs/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory
|
||||
- [CUTLASS Profiler](./media/docs/profiler.md) - command-line driven profiling application
|
||||
- [CUTLASS Utilities](./media/docs/utilities.md) - additional templates used to facilitate rapid development
|
||||
- [Dependent kernel launch](./media/docs/dependent_kernel_launch.md) - describes a new feature in Hopper which allows overlapping dependent
|
||||
kernels in the same stream, and how it is used in CUTLASS.
|
||||
|
||||
# Resources
|
||||
We have also described the structure of an efficient GEMM in our talk at the
|
||||
[GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).
|
||||
|
||||
- [CUTLASS: Software Primitives for Dense Linear Algebra at All Levels and Scales within CUDA](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2018-s8854/)
|
||||
- [Developing CUDA Kernels to Push Tensor Cores to the Absolute Limit on NVIDIA A100](https://www.nvidia.com/en-us/on-demand/session/gtcsj20-s21745/)
|
||||
- [Accelerating Convolution with Tensor Cores in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring21-s31883/)
|
||||
- [Accelerating Backward Data Gradient by Increasing Tensor Core Utilization in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41996/)
|
||||
- [CUTLASS: Python API, Enhancements, and NVIDIA Hopper](https://www.nvidia.com/en-us/on-demand/session/gtcfall22-a41131/)
|
||||
- [CUTLASS: Software Primitives for Dense Linear Algebra at All Levels and Scales within CUDA](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2018-s8854/)
|
||||
- [Developing CUDA Kernels to Push Tensor Cores to the Absolute Limit on NVIDIA A100](https://www.nvidia.com/en-us/on-demand/session/gtcsj20-s21745/)
|
||||
- [Accelerating Convolution with Tensor Cores in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring21-s31883/)
|
||||
- [Accelerating Backward Data Gradient by Increasing Tensor Core Utilization in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41996/)
|
||||
- [CUTLASS: Python API, Enhancements, and NVIDIA Hopper](https://www.nvidia.com/en-us/on-demand/session/gtcfall22-a41131/)
|
||||
|
||||
# Building CUTLASS
|
||||
|
||||
@ -178,7 +267,8 @@ CUTLASS is a header-only template library and does not need to be built to be us
|
||||
projects. Client applications should target CUTLASS's `include/` directory in their include
|
||||
paths.
|
||||
|
||||
CUTLASS unit tests, examples, and utilities can be build with CMake starting version 3.12.
|
||||
CUTLASS unit tests, examples, and utilities can be build with CMake.
|
||||
The minimum version of CMake is given in the [Quickstart guide](./media/docs/quickstart.md).
|
||||
Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed
|
||||
on your system.
|
||||
|
||||
@ -223,7 +313,7 @@ CUTLASS is arranged as a header-only library along with Utilities, Tools, Exampl
|
||||
and template concepts defined in the CUTLASS project.
|
||||
|
||||
A detailed explanation of the source code organization may be found in the
|
||||
[CUTLASS documentation](media/docs/code_organization.md), but several main components are summarized below.
|
||||
[CUTLASS documentation](./media/docs/code_organization.md), but several main components are summarized below.
|
||||
|
||||
## CUTLASS Template Library
|
||||
|
||||
@ -272,7 +362,7 @@ include/ # client applications should target this directory
|
||||
|
||||
### CUTLASS SDK Examples
|
||||
|
||||
[CUTLASS SDK examples](/examples) apply CUTLASS templates to implement basic computations.
|
||||
[CUTLASS SDK examples](./examples) apply CUTLASS templates to implement basic computations.
|
||||
|
||||
### Tools
|
||||
|
||||
@ -297,7 +387,7 @@ tools/
|
||||
The `test/unit/` directory consist of unit tests implemented with Google Test that demonstrate
|
||||
basic usage of Core API components and complete tests of the CUTLASS GEMM computations.
|
||||
|
||||
Instructions for building and running the Unit tests are described in the [Quickstart guide](media/docs/quickstart.md).
|
||||
Instructions for building and running the Unit tests are described in the [Quickstart guide](./media/docs/quickstart.md).
|
||||
|
||||
# Performance Profiling
|
||||
|
||||
@ -513,9 +603,9 @@ reference_device: Passed
|
||||
|
||||
## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler
|
||||
- Please follow the links for more CMake examples on selectively compiling CUTLASS kernels:
|
||||
- [GEMM CMake Examples](media/docs/quickstart.md#gemm-cmake-examples)
|
||||
- [Implicit GEMM conovlution CMake Examples](media/docs/quickstart.md#convolution-cmake-examples)
|
||||
- [Further details about the CUTLASS Profiler are described here.](media/docs/profiler.md)
|
||||
- [GEMM CMake Examples](./media/docs/quickstart.md#gemm-cmake-examples)
|
||||
- [Implicit GEMM convolution CMake Examples](./media/docs/quickstart.md#convolution-cmake-examples)
|
||||
- [Further details about the CUTLASS Profiler are described here.](./media/docs/profiler.md)
|
||||
|
||||
|
||||
# About
|
||||
@ -529,7 +619,7 @@ The official list of CUTLASS developers and contributors is available here: [CON
|
||||
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
@ -558,4 +648,3 @@ SPDX-License-Identifier: BSD-3-Clause
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
|
||||
@ -1,3 +1,31 @@
|
||||
# Copyright (c) 2019 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# A small utility function which generates a C-header from an input file
|
||||
function(FILE_TO_C_STRING FILENAME VARIABLE_NAME OUTPUT_STRING ZERO_TERMINATED)
|
||||
FILE(READ "${FILENAME}" HEX_INPUT HEX)
|
||||
@ -6,7 +34,7 @@ function(FILE_TO_C_STRING FILENAME VARIABLE_NAME OUTPUT_STRING ZERO_TERMINATED)
|
||||
endif()
|
||||
|
||||
string(REGEX REPLACE "(....)" "\\1\n" HEX_OUTPUT ${HEX_INPUT})
|
||||
string(REGEX REPLACE "([0-9a-f][0-9a-f])" "0x\\1," HEX_OUTPUT ${HEX_OUTPUT})
|
||||
string(REGEX REPLACE "([0-9a-f][0-9a-f])" "char(0x\\1)," HEX_OUTPUT ${HEX_OUTPUT})
|
||||
|
||||
set(HEX_OUTPUT "static char const ${VARIABLE_NAME}[] = {\n ${HEX_OUTPUT}\n};\n")
|
||||
|
||||
|
||||
@ -1,21 +0,0 @@
|
||||
# Generated file
|
||||
|
||||
if (DEFINED ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT})
|
||||
set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT $ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT})
|
||||
else()
|
||||
set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT @CUTLASS_TEST_EXECUTION_ENVIRONMENT@)
|
||||
endif()
|
||||
|
||||
if (NOT "@TEST_EXE_DIR@" STREQUAL "")
|
||||
set(TEST_EXE_PATH @TEST_EXE_DIR@/@TEST_EXE@)
|
||||
else()
|
||||
set(TEST_EXE_PATH @TEST_EXE@)
|
||||
endif()
|
||||
|
||||
add_test("@TEST_NAME@" ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@)
|
||||
|
||||
if (NOT "@TEST_EXE_WORKING_DIRECTORY@" STREQUAL "")
|
||||
set_tests_properties("@TEST_NAME@" PROPERTIES WORKING_DIRECTORY "@TEST_EXE_WORKING_DIRECTORY@")
|
||||
endif()
|
||||
|
||||
set_tests_properties(@TEST_NAME@ PROPERTIES DISABLED @__DISABLE_TESTS@)
|
||||
52
cmake/CTestTestfile.configure.cmake
Normal file
52
cmake/CTestTestfile.configure.cmake
Normal file
@ -0,0 +1,52 @@
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# Generated file
|
||||
|
||||
set(TEST_SETS_SUPPORTED @TEST_SETS_SUPPORTED@)
|
||||
|
||||
if (NOT DEFINED ENV{CUTLASS_TEST_SETS})
|
||||
set(ENV{CUTLASS_TEST_SETS} @CUTLASS_DEFAULT_ACTIVE_TEST_SETS@)
|
||||
endif()
|
||||
|
||||
foreach(TEST_SET_REQUESTED IN ITEMS $ENV{CUTLASS_TEST_SETS})
|
||||
if (NOT TEST_SET_REQUESTED IN_LIST TEST_SETS_SUPPORTED)
|
||||
message(STATUS "Skipping tests for @TEST_EXE_PATH@ as ${TEST_SET_REQUESTED} is not in the set of [${TEST_SETS_SUPPORTED}].")
|
||||
return()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
set(TEST_EXE_PATH @TEST_EXE_PATH@)
|
||||
set(TEST_EXE_WORKING_DIRECTORY @TEST_EXE_WORKING_DIRECTORY@)
|
||||
set(CUTLASS_USE_EXTENDED_ADD_TEST_FORMAT @TEST_USE_EXTENDED_FORMAT@)
|
||||
|
||||
if (DEFINED ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT})
|
||||
set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT $ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT})
|
||||
else()
|
||||
set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT @CUTLASS_TEST_EXECUTION_ENVIRONMENT@)
|
||||
endif()
|
||||
43
cmake/CTestTestfile.test.configure.cmake
Normal file
43
cmake/CTestTestfile.test.configure.cmake
Normal file
@ -0,0 +1,43 @@
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
if (CUTLASS_USE_EXTENDED_ADD_TEST_FORMAT)
|
||||
# The longform/extended format allows generator expressions to be
|
||||
# expanded property and is useful in contexts where the files need
|
||||
# to be immediately included into being-processed cmake code.
|
||||
add_test(NAME @TESTCASE_NAME@ COMMAND ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@)
|
||||
else()
|
||||
add_test(@TESTCASE_NAME@ ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@)
|
||||
endif()
|
||||
|
||||
if (TEST_EXE_WORKING_DIRECTORY)
|
||||
set_tests_properties(@TESTCASE_NAME@ PROPERTIES WORKING_DIRECTORY "${TEST_EXE_WORKING_DIRECTORY}")
|
||||
endif()
|
||||
|
||||
set_tests_properties(@TESTCASE_NAME@ PROPERTIES DISABLED @__DISABLE_TESTS@)
|
||||
|
||||
@ -2,6 +2,8 @@ get_filename_component(NvidiaCutlass_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH
|
||||
|
||||
include(CMakeFindDependencyMacro)
|
||||
|
||||
if(NOT TARGET nvidia::cutlass::CUTLASS)
|
||||
include("${NvidiaCutlass_CMAKE_DIR}/NvidiaCutlassTargets.cmake")
|
||||
if(TARGET nvidia::cutlass::CUTLASS)
|
||||
return()
|
||||
endif()
|
||||
|
||||
include("${NvidiaCutlass_CMAKE_DIR}/NvidiaCutlassTargets.cmake")
|
||||
@ -1,3 +1,31 @@
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
set(CPACK_PACKAGE_NAME NvidiaCutlass)
|
||||
set(CPACK_PACKAGE_VENDOR NVIDIA)
|
||||
set(CPACK_PACKAGE_CONTACT info@nvidia.com)
|
||||
|
||||
@ -1,3 +1,31 @@
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
set(GOOGLETEST_DIR "" CACHE STRING "Location of local GoogleTest repo to build against")
|
||||
@ -6,10 +34,11 @@ if(GOOGLETEST_DIR)
|
||||
set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST ${GOOGLETEST_DIR} CACHE STRING "GoogleTest source directory override")
|
||||
endif()
|
||||
|
||||
set(GTEST_REPOSITORY "https://github.com/google/googletest.git" CACHE STRING "GoogleTest repo to fetch")
|
||||
FetchContent_Declare(
|
||||
googletest
|
||||
GIT_REPOSITORY https://github.com/google/googletest.git
|
||||
GIT_TAG 0fe9660
|
||||
GIT_REPOSITORY ${GTEST_REPOSITORY}
|
||||
GIT_TAG v1.14.0
|
||||
)
|
||||
|
||||
FetchContent_GetProperties(googletest)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,38 +0,0 @@
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
#define CUTLASS_MAJOR @CUTLASS_VERSION_MAJOR@
|
||||
#define CUTLASS_MINOR @CUTLASS_VERSION_MINOR@
|
||||
#define CUTLASS_PATCH @CUTLASS_VERSION_PATCH@
|
||||
#define CUTLASS_BUILD @CUTLASS_VERSION_BUILD@
|
||||
#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
inline uint32_t getVersion() {
|
||||
return CUTLASS_VERSION;
|
||||
}
|
||||
inline uint32_t getVersionMajor() {
|
||||
return CUTLASS_MAJOR;
|
||||
}
|
||||
inline uint32_t getVersionMinor() {
|
||||
return CUTLASS_MINOR;
|
||||
}
|
||||
inline uint32_t getVersionPatch() {
|
||||
return CUTLASS_PATCH;
|
||||
}
|
||||
inline uint32_t getVersionBuild() {
|
||||
return CUTLASS_BUILD + 0;
|
||||
}
|
||||
inline std::string getVersionString() {
|
||||
std::string version = "@CUTLASS_VERSION@";
|
||||
if (getVersionBuild()) {
|
||||
version += "." + std::to_string(getVersionBuild());
|
||||
}
|
||||
return version;
|
||||
}
|
||||
inline std::string getGitRevision() {
|
||||
return "@CUTLASS_REVISION@";
|
||||
}
|
||||
|
||||
} // namespace cutlass
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -30,14 +30,5 @@
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <vector_types.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
namespace cute {
|
||||
|
||||
using cutlass::float_e4m3_t;
|
||||
using cutlass::float_e5m2_t;
|
||||
|
||||
} // end namespace cute
|
||||
#define CUTLASS_BUILD @CUTLASS_VERSION_BUILD@
|
||||
#define CUTLASS_REVISION "@CUTLASS_REVISION@"
|
||||
@ -1,4 +1,4 @@
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
|
||||
92
customConfigs.cmake
Normal file
92
customConfigs.cmake
Normal file
@ -0,0 +1,92 @@
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# Profiler based functional testing
|
||||
set(CUTLASS_BUILD_FOR_PROFILER_REGRESSIONS OFF CACHE BOOL "Utilize profiler-based functional regressions")
|
||||
set(CUTLASS_PROFILER_REGRESSION_TEST_LEVEL ${CUTLASS_TEST_LEVEL} CACHE STRING "Profiler functional regression test level")
|
||||
|
||||
find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED)
|
||||
|
||||
function(cutlass_generate_kernel_filter_and_testlists_files)
|
||||
|
||||
set(options)
|
||||
set(oneValueArgs TEST_SET_NAME)
|
||||
set(multiValueArgs)
|
||||
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CUTLASS_LIBRARY_PACKAGE_DIR}
|
||||
${Python3_EXECUTABLE} ${CUTLASS_SOURCE_DIR}/python/cutlass_library/generator.py
|
||||
--generator-target=${__TEST_SET_NAME}
|
||||
--cuda-version=${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR}
|
||||
--architectures=${CUTLASS_NVCC_ARCHS}
|
||||
--kernels=\*
|
||||
--disable-cutlass-package-imports
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
|
||||
RESULT_VARIABLE cutlass_FILTER_GENERATION_RESULT
|
||||
OUTPUT_VARIABLE cutlass_FILTER_GENERATION_OUTPUT
|
||||
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/library_filter_generation.log
|
||||
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/library_filter_generation.log
|
||||
)
|
||||
|
||||
if(NOT cutlass_FILTER_GENERATION_RESULT EQUAL 0)
|
||||
message(FATAL_ERROR "Error generating kernel filters and testlists files. See ${CMAKE_CURRENT_BINARY_DIR}/library_filter_generation.log")
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
if(CUTLASS_BUILD_FOR_PROFILER_REGRESSIONS)
|
||||
|
||||
set(PROFILER_ARCH_LIST 100a)
|
||||
foreach(ARCH IN LISTS CUTLASS_NVCC_ARCHS)
|
||||
if(NOT (ARCH IN_LIST PROFILER_ARCH_LIST))
|
||||
message(FATAL_ERROR "Only SM100a compute capability is supported with profiler-based unit tests")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
if(CUTLASS_PROFILER_REGRESSION_TEST_LEVEL EQUAL 0)
|
||||
|
||||
message(STATUS "Building for L0 profiler-based functional regressions")
|
||||
cutlass_generate_kernel_filter_and_testlists_files(TEST_SET_NAME kernel_testlist_l0)
|
||||
set(KERNEL_FILTER_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L0_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm_kernel_filter.list CACHE STRING "Kernel set")
|
||||
set(CUTLASS_PROFILER_REGRESSION_LIST_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L0_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm.csv CACHE STRING "Regression set")
|
||||
|
||||
elseif (CUTLASS_PROFILER_REGRESSION_TEST_LEVEL EQUAL 1)
|
||||
|
||||
message(STATUS "Building for L1 profiler-based functional regressions")
|
||||
cutlass_generate_kernel_filter_and_testlists_files(TEST_SET_NAME kernel_testlist_l1)
|
||||
set(KERNEL_FILTER_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L1_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm_kernel_filter.list CACHE STRING "Kernel set")
|
||||
set(CUTLASS_PROFILER_REGRESSION_LIST_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L1_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm.csv CACHE STRING "Regression set")
|
||||
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
@ -1677,7 +1677,7 @@ template<typename Element , typename Layout > </div>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
<p>Returns a pair containing a boolean of whether a value exists in a tensor and the location of of the first occurrence. If the value is not contained in the tensor, the second element of the pair is undefined. </p>
|
||||
<p>Returns a pair containing a boolean of whether a value exists in a tensor and the location of the first occurrence. If the value is not contained in the tensor, the second element of the pair is undefined. </p>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
@ -31,4 +31,5 @@
|
||||
cutlass_example_add_executable(
|
||||
02_dump_reg_shmem
|
||||
dump_reg_shmem.cu
|
||||
DISABLE_TESTS ON
|
||||
)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -260,7 +260,7 @@ private:
|
||||
if (options.vectorize <= 2) return std::make_pair(false, -1);
|
||||
|
||||
// Boundary check.
|
||||
if (i > elements.size() || (i + options.vectorize - 1) > elements.size())
|
||||
if (i > int(elements.size()) || (i + options.vectorize - 1) > int(elements.size()))
|
||||
return std::make_pair(false, -1);
|
||||
|
||||
// Check if either all elements are valid or invalid.
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -94,7 +94,7 @@ __global__ void copy(
|
||||
|
||||
typename Iterator::Fragment fragment;
|
||||
|
||||
for(int i = 0; i < fragment.size(); ++i) {
|
||||
for(size_t i = 0; i < fragment.size(); ++i) {
|
||||
fragment[i] = 0;
|
||||
}
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -207,15 +207,15 @@ cudaError_t strided_batched_gemm_nn_reference(
|
||||
|
||||
cudaError_t result = cudaSuccess;
|
||||
|
||||
if (A.size() < lda * k * batch_count) {
|
||||
if (A.size() < size_t(lda * k * batch_count)) {
|
||||
std::cout << "the size of A is too small" << std::endl;
|
||||
return cudaErrorInvalidValue;
|
||||
}
|
||||
if (B.size() < ldb * n) {
|
||||
if (B.size() < size_t(ldb * n)) {
|
||||
std::cout << "the size of B is too small" << std::endl;
|
||||
return cudaErrorInvalidValue;
|
||||
}
|
||||
if (C.size() < ldc * n * batch_count) {
|
||||
if (C.size() < size_t(ldc * n * batch_count)) {
|
||||
std::cout << "the size of C is too small" << std::endl;
|
||||
return cudaErrorInvalidValue;
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -162,7 +162,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M =
|
||||
using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 4>; // <- MMA Op tile M = 8, N = 8, K = 4
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
// This code section describes ?
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -140,8 +140,8 @@ using ElementInputA = int8_t; // <- data type of elements
|
||||
using ElementInputB = int8_t; // <- data type of elements in input matrix B
|
||||
using ElementOutput = int32_t; // <- data type of elements in output matrix D
|
||||
|
||||
// The code section below describes matrix layout of input and output matrices. Column Major for
|
||||
// Matrix A, Row Major for Matrix B and Row Major for Matrix C
|
||||
// The code section below describes matrix layout of input and output matrices. Row Major for
|
||||
// Matrix A, Column Major for Matrix B and Row Major for Matrix C
|
||||
using LayoutInputA = cutlass::layout::RowMajor;
|
||||
using LayoutInputB = cutlass::layout::ColumnMajor;
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
@ -161,7 +161,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 64>; // <- warp tile M =
|
||||
using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 16>; // <- MMA Op tile M = 8, N = 8, K = 16
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
// This code section describes the epilogue part of the kernel
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
@ -355,4 +355,3 @@ int main() {
|
||||
|
||||
return run();
|
||||
}
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -143,7 +143,6 @@ compare if the output from CUTLASS kernel is same as the reference implicit GEMM
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
// The code section below describes datatype for input, output tensors and computation between
|
||||
// elements
|
||||
using ElementAccumulator = int32_t; // Data type of accumulator
|
||||
@ -555,6 +554,7 @@ Result profile_convolution(Options const &options) {
|
||||
LayoutOutput,
|
||||
ElementComputeEpilogue,
|
||||
ElementAccumulator,
|
||||
ElementOutput,
|
||||
cutlass::NumericConverterClamp<ElementOutput, ElementComputeEpilogue>
|
||||
>(
|
||||
problem_size,
|
||||
@ -674,7 +674,6 @@ Result profile_convolution(Options const &options) {
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
@ -761,11 +760,7 @@ int main(int argc, char const **args) {
|
||||
Result::print_header(std::cout, options) << std::endl;
|
||||
result.print(std::cout, 1, options) << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
@ -27,7 +27,10 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
#
|
||||
# This example depends on the CUTLASS Library
|
||||
#
|
||||
if (CUTLASS_ENABLE_LIBRARY)
|
||||
|
||||
# Planar Complex GEMM example
|
||||
cutlass_example_add_executable(
|
||||
@ -35,11 +38,6 @@ cutlass_example_add_executable(
|
||||
planar_complex.cu
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# This example depends on the CUTLASS Library
|
||||
#
|
||||
|
||||
target_link_libraries(
|
||||
10_planar_complex
|
||||
PRIVATE
|
||||
@ -48,3 +46,4 @@ target_link_libraries(
|
||||
cuda
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
@ -27,7 +27,10 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
#
|
||||
# This example depends on the CUTLASS Library
|
||||
#
|
||||
if (CUTLASS_ENABLE_LIBRARY)
|
||||
|
||||
# Planar Complex Array GEMM example
|
||||
cutlass_example_add_executable(
|
||||
@ -35,11 +38,6 @@ cutlass_example_add_executable(
|
||||
planar_complex_array.cu
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# This example depends on the CUTLASS Library
|
||||
#
|
||||
|
||||
target_link_libraries(
|
||||
11_planar_complex_array
|
||||
PRIVATE
|
||||
@ -48,3 +46,4 @@ target_link_libraries(
|
||||
cuda
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -81,10 +81,10 @@ using ShapeMMAThreadBlock =
|
||||
// This code section describes tile size a warp will compute
|
||||
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = 64, N = 64, K = 32
|
||||
// This code section describes the size of MMA op
|
||||
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 8, N = 8, K = 4
|
||||
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
// Define the epilogue operation as LinearCombinationRelu. This is approximately equal to
|
||||
//
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
|
||||
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
@ -64,6 +64,7 @@ endforeach()
|
||||
foreach(FUSION_GEMM_EXAMPLE
|
||||
fused_two_gemms_f16_sm75_rf
|
||||
fused_two_gemms_f16_sm75_shmem
|
||||
fused_two_gemms_grouped_f16_sm80_rf
|
||||
fused_two_gemms_f16_sm80_rf
|
||||
fused_two_gemms_f16_sm80_shmem
|
||||
fused_two_gemms_s8_sm75_rf
|
||||
@ -79,4 +80,3 @@ foreach(FUSION_GEMM_EXAMPLE
|
||||
add_dependencies(13_fused_two_gemms 13_${FUSION_GEMM_EXAMPLE})
|
||||
|
||||
endforeach()
|
||||
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
# Introduction
|
||||
|
||||
This example shows fusing two back-to-back GEMMs/Convolutions into one kernel.
|
||||
This example shows fusing two back-to-back GEMMs/Convolutions into one kernel.
|
||||
|
||||
<p align="center"><img src=/media/images/13_example_fusion.png></p>
|
||||
|
||||
When running two unfused GEMM/Conv operations, each operation loads one input
|
||||
activation matrix, one weight matrix (or filter matrix) from the memory and then
|
||||
When running two unfused GEMM/Conv operations, each operation loads one input
|
||||
activation matrix, one weight matrix (or filter matrix) from the memory and then
|
||||
stores the result activation matrix back to the memory.
|
||||
|
||||
When the two GEMM/Conv operations are fused together, the mainloops of the two
|
||||
@ -27,10 +27,10 @@ In order to run two GEMM/Convs in a single kernel, the example requires the same
|
||||
threadblocks are used across 2 GEMMs/Convs. This also ensures the same threadblock tile M across
|
||||
2 GEMMs/Convs.
|
||||
|
||||
In order to reuse the output accumulator (stored in register-file) of the 1st GEMM as the
|
||||
In order to reuse the output accumulator (stored in register-file) of the 1st GEMM as the
|
||||
input activation, the example enforces the following two constraints:
|
||||
|
||||
- thread_block_tile_N = problem_N
|
||||
- thread_block_tile_N = problem_N
|
||||
|
||||
<p align="center"><img src=/media/images/13_example_block_resident_fusion.png></p>
|
||||
|
||||
@ -39,7 +39,7 @@ addition to its own input activation tile. Therefore the input activation tile o
|
||||
2nd GEMM/Conv only depends on the output activation tile of the 1st GEMM/Conv, and the
|
||||
operation can be fully block-resident.
|
||||
|
||||
- warp_tile_N = thread_block_tile_N
|
||||
- warp_tile_N = thread_block_tile_N
|
||||
|
||||
<p align="center"><img src=/media/images/13_example_rf_resident_fusion.png></p>
|
||||
|
||||
@ -82,11 +82,11 @@ threadblock. Typically this requires the 2nd Convolution uses 1x1 filter without
|
||||
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm75_shmem`
|
||||
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_rf`
|
||||
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_shmem`
|
||||
|
||||
|
||||
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -42,6 +42,7 @@
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/gemm_complex.h"
|
||||
#include "cutlass/util/reference/device/tensor_relu.h"
|
||||
|
||||
#include "reference/device/tensor_scale_bias.h"
|
||||
@ -77,9 +78,9 @@ struct B2bNonFusedGemmRun
|
||||
//
|
||||
|
||||
B2bNonFusedGemmRun(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
):
|
||||
@ -88,7 +89,7 @@ struct B2bNonFusedGemmRun
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
@ -96,7 +97,7 @@ struct B2bNonFusedGemmRun
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, 2, -2, 0);
|
||||
}
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
@ -129,62 +130,62 @@ struct B2bNonFusedGemmRun
|
||||
|
||||
/// Executes one test
|
||||
bool run(
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementA,
|
||||
typename Gemm0::ElementA,
|
||||
typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementB,
|
||||
typename Gemm0::ElementB,
|
||||
typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
ElementCompute,
|
||||
ElementCompute,
|
||||
typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementB,
|
||||
typename Gemm1::ElementB,
|
||||
typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
ElementCompute,
|
||||
ElementCompute,
|
||||
typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
|
||||
|
||||
|
||||
@ -270,13 +271,13 @@ struct B2bNonFusedGemmRun
|
||||
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = gemm_op_0();
|
||||
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
cudaEventRecord(stop1);
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = gemm_op_1();
|
||||
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
@ -312,32 +313,32 @@ struct B2bNonFusedGemmRun
|
||||
|
||||
reference_gemm_0(
|
||||
problem_size_0,
|
||||
alpha0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
beta0,
|
||||
alpha0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
beta0,
|
||||
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
|
||||
reference_D0.device_ref()
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
}
|
||||
|
||||
reference_gemm_1(
|
||||
problem_size_1,
|
||||
alpha1,
|
||||
reference_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
alpha1,
|
||||
reference_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
beta1,
|
||||
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
|
||||
reference_D1.device_ref()
|
||||
);
|
||||
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
}
|
||||
|
||||
|
||||
// Wait for kernels to finish
|
||||
cudaDeviceSynchronize();
|
||||
reference_D0.sync_host();
|
||||
@ -349,7 +350,7 @@ struct B2bNonFusedGemmRun
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
reference_D1.host_view(),
|
||||
reference_D1.host_view(),
|
||||
tensor_D1.host_view());
|
||||
|
||||
CHECK_TRUE(passed);
|
||||
@ -362,7 +363,7 @@ struct B2bNonFusedGemmRun
|
||||
|
||||
std::ofstream file(fname.str());
|
||||
|
||||
file
|
||||
file
|
||||
<< "A0 =\n" << tensor_A0.host_view()
|
||||
<< "\nB0 =\n" << tensor_B0.host_view()
|
||||
<< "\nC0 =\n" << tensor_C0.host_view()
|
||||
@ -399,9 +400,9 @@ struct B2bFusedGemmRun
|
||||
//
|
||||
|
||||
B2bFusedGemmRun(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
@ -412,7 +413,7 @@ struct B2bFusedGemmRun
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
@ -420,11 +421,11 @@ struct B2bFusedGemmRun
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, 2, -2, 0);
|
||||
}
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
@ -453,70 +454,90 @@ struct B2bFusedGemmRun
|
||||
|
||||
/// Executes one test
|
||||
bool run(
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
|
||||
// batch_count is used as split-k when mode is kGemm according
|
||||
// to the GemmUniversal interface
|
||||
|
||||
int batch_count = 1,
|
||||
int64_t batch_stride_A0 = 0,
|
||||
int64_t batch_stride_B0 = 0,
|
||||
int64_t batch_stride_C0 = 0,
|
||||
int64_t batch_stride_B1 = 0,
|
||||
int64_t batch_stride_C1 = 0,
|
||||
int64_t batch_stride_D1 = 0,
|
||||
int64_t batch_stride_Bias0 = 0,
|
||||
int64_t batch_stride_Scale0 = 0,
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementA,
|
||||
typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk());
|
||||
cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
|
||||
cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
|
||||
cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k());
|
||||
cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k());
|
||||
cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn());
|
||||
typename B2bGemm::ElementA,
|
||||
typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Scale0;
|
||||
|
||||
if(alpha0 == ElementCompute(0)) //per-channel scale
|
||||
tensor_Scale0.resize({1, problem_size_0.n()});
|
||||
tensor_Scale0.resize({1, batch_count * problem_size_0.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()});
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
ElementAccumulator,
|
||||
typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn());
|
||||
ElementAccumulator,
|
||||
typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D0(CoordC0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn());
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
ElementCompute,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()});
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D1(CoordC1.mn());
|
||||
|
||||
|
||||
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
|
||||
@ -554,6 +575,7 @@ struct B2bFusedGemmRun
|
||||
//
|
||||
|
||||
typename B2bGemm::Arguments arguments{
|
||||
mode,
|
||||
problem_size_0,
|
||||
problem_size_1,
|
||||
tensor_A0.device_ref(),
|
||||
@ -564,8 +586,16 @@ struct B2bFusedGemmRun
|
||||
tensor_B1.device_ref(),
|
||||
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
|
||||
tensor_D1.device_ref(),
|
||||
batch_stride_A0,
|
||||
batch_stride_B0,
|
||||
batch_stride_B1,
|
||||
batch_stride_C1,
|
||||
batch_stride_D1,
|
||||
batch_stride_Bias0,
|
||||
batch_stride_Scale0,
|
||||
{alpha0, beta0},
|
||||
{alpha1, beta1},
|
||||
batch_count,
|
||||
};
|
||||
|
||||
B2bGemm b2b_gemm_op;
|
||||
@ -618,32 +648,31 @@ struct B2bFusedGemmRun
|
||||
// Verify
|
||||
//
|
||||
|
||||
cutlass::reference::device::Gemm<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
ElementAccumulator, typename B2bGemm::LayoutC,
|
||||
ElementAccumulator, ElementAccumulator>
|
||||
reference_gemm_0;
|
||||
cutlass::reference::device::GemmComplex<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
ElementAccumulator, typename B2bGemm::LayoutC,
|
||||
ElementAccumulator, ElementAccumulator
|
||||
>(
|
||||
|
||||
cutlass::reference::device::Gemm<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
|
||||
ElementAccumulator, typename B2bGemm::Operator>
|
||||
reference_gemm_1;
|
||||
|
||||
reference_gemm_0(
|
||||
problem_size_0,
|
||||
ElementAccumulator(1), //intermediate alpha=1
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
tensor_A0.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
tensor_B0.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
ElementAccumulator(0), //beta = 0
|
||||
reference_Z0.device_ref(),
|
||||
reference_Z0.device_ref(),
|
||||
ElementAccumulator(0)
|
||||
ElementAccumulator(0),
|
||||
int(batch_count),
|
||||
batch_stride_A0,
|
||||
batch_stride_B0,
|
||||
batch_stride_C0,
|
||||
batch_stride_C0
|
||||
);
|
||||
|
||||
cutlass::reference::device::TensorScaleBiasGemm<
|
||||
cutlass::reference::device::TensorScaleBiasGemmBatched<
|
||||
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
|
||||
ElementCompute, typename B2bGemm::LayoutScaleBias
|
||||
> (
|
||||
@ -652,25 +681,45 @@ struct B2bFusedGemmRun
|
||||
reference_D0.device_ref(),
|
||||
alpha0,
|
||||
tensor_Scale0.device_ref(),
|
||||
tensor_Bias0.device_ref()
|
||||
tensor_Bias0.device_ref(),
|
||||
int(batch_count),
|
||||
batch_stride_C0,
|
||||
batch_stride_C0,
|
||||
batch_stride_Scale0,
|
||||
batch_stride_Bias0
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
}
|
||||
|
||||
reference_gemm_1(
|
||||
cutlass::reference::device::GemmComplex<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
|
||||
ElementCompute, ElementAccumulator
|
||||
>(
|
||||
problem_size_1,
|
||||
alpha1,
|
||||
reference_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
beta1,
|
||||
alpha1, //intermediate alpha=1
|
||||
reference_D0.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
tensor_B1.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
beta1, //beta = 0
|
||||
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
|
||||
reference_D1.device_ref()
|
||||
reference_D1.device_ref(),
|
||||
ElementAccumulator(0),
|
||||
int(batch_count),
|
||||
batch_stride_C0,
|
||||
batch_stride_B1,
|
||||
batch_stride_C1,
|
||||
batch_stride_D1
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
}
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
reference_D0.sync_host();
|
||||
reference_D1.sync_host();
|
||||
@ -680,7 +729,7 @@ struct B2bFusedGemmRun
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
reference_D1.host_view(),
|
||||
reference_D1.host_view(),
|
||||
tensor_D1.host_view());
|
||||
|
||||
CHECK_TRUE(passed);
|
||||
@ -694,7 +743,7 @@ struct B2bFusedGemmRun
|
||||
|
||||
std::ofstream file(fname.str());
|
||||
|
||||
file
|
||||
file
|
||||
<< "A0 =\n" << tensor_A0.host_view()
|
||||
<< "\nB0 =\n" << tensor_B0.host_view()
|
||||
<< "\nC0 =\n" << tensor_C0.host_view()
|
||||
|
||||
450
examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h
Normal file
450
examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h
Normal file
@ -0,0 +1,450 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Containers for running grouped back-to-back GEMMs
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "cutlass/util/device_memory.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_relu.h"
|
||||
|
||||
#include "reference/device/tensor_scale_bias.h"
|
||||
#include "helper.h"
|
||||
|
||||
#define CHECK_GT(val1, val2) \
|
||||
if((val1) <= (val2)) \
|
||||
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
|
||||
#define CHECK_TRUE(val) \
|
||||
if(!(val)) \
|
||||
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename B2bGemm_>
|
||||
struct B2bFusedGroupedGemmRun
|
||||
{
|
||||
|
||||
using B2bGemm = B2bGemm_;
|
||||
using ElementAccumulator = typename B2bGemm::ElementAccumulator;
|
||||
using ElementCompute = typename B2bGemm::BaseKernel::Epilogue::OutputOp::ElementCompute;
|
||||
|
||||
/// Initialization
|
||||
cutlass::Distribution::Kind init_A;
|
||||
cutlass::Distribution::Kind init_B;
|
||||
cutlass::Distribution::Kind init_C;
|
||||
cutlass::Distribution::Kind init_Scale;
|
||||
cutlass::Distribution::Kind init_Bias;
|
||||
uint64_t seed;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
B2bFusedGroupedGemmRun(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
):
|
||||
init_A(init_A_), init_B(init_B_), init_C(init_C_),
|
||||
init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }
|
||||
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, 1, -1, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
|
||||
cutlass::reference::host::BlockFillSequential(
|
||||
view.data(), view.capacity());
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view, Element(0));
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllOnes) {
|
||||
cutlass::reference::host::TensorFill(view, Element(1));
|
||||
}
|
||||
else {
|
||||
std::cerr << "Not implemented\n";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Executes one test
|
||||
bool run(
|
||||
std::vector<cutlass::gemm::GemmCoord> problem_sizes_0,
|
||||
std::vector<cutlass::gemm::GemmCoord> problem_sizes_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
using HostTensorA = cutlass::HostTensor<typename B2bGemm::ElementA, typename B2bGemm::LayoutA>;
|
||||
using HostTensorB = cutlass::HostTensor<typename B2bGemm::ElementB, typename B2bGemm::LayoutB>;
|
||||
using HostTensorC = cutlass::HostTensor<typename B2bGemm::ElementC, typename B2bGemm::LayoutC>;
|
||||
using HostTensorScale = cutlass::HostTensor<ElementCompute, typename B2bGemm::LayoutC>;
|
||||
using HostTensorZ = cutlass::HostTensor<ElementAccumulator, typename B2bGemm::LayoutC>;
|
||||
using HostTensorBias = cutlass::HostTensor<ElementCompute, typename B2bGemm::LayoutC>;
|
||||
|
||||
int problem_count = (int)problem_sizes_0.size();
|
||||
|
||||
std::vector<HostTensorA> host_tensor_A0(problem_count);
|
||||
std::vector<HostTensorB> host_tensor_B0(problem_count);
|
||||
std::vector<HostTensorC> host_tensor_C0(problem_count);
|
||||
std::vector<HostTensorScale> host_tensor_Scale0(problem_count);
|
||||
std::vector<HostTensorScale> host_tensor_Bias0(problem_count);
|
||||
std::vector<HostTensorB> host_tensor_B1(problem_count);
|
||||
std::vector<HostTensorC> host_tensor_C1(problem_count);
|
||||
std::vector<HostTensorBias> host_tensor_Bias1(problem_count);
|
||||
std::vector<HostTensorC> host_tensor_D1(problem_count);
|
||||
std::vector<HostTensorZ> host_tensor_Z(problem_count);
|
||||
std::vector<HostTensorC> host_tensor_ref_D0(problem_count);
|
||||
std::vector<HostTensorC> host_tensor_ref_D1(problem_count);
|
||||
|
||||
std::vector<typename HostTensorA::TensorRef> ref_A0(problem_count);
|
||||
std::vector<typename HostTensorB::TensorRef> ref_B0(problem_count);
|
||||
std::vector<typename HostTensorC::TensorRef> ref_C0(problem_count);
|
||||
std::vector<typename HostTensorScale::TensorRef> ref_Scale0(problem_count);
|
||||
std::vector<typename HostTensorScale::TensorRef> ref_Bias0(problem_count);
|
||||
std::vector<typename HostTensorB::TensorRef> ref_B1(problem_count);
|
||||
std::vector<typename HostTensorC::TensorRef> ref_C1(problem_count);
|
||||
std::vector<typename HostTensorBias::TensorRef> ref_Bias1(problem_count);
|
||||
std::vector<typename HostTensorC::TensorRef> ref_D1(problem_count);
|
||||
std::vector<typename HostTensorZ::TensorRef> ref_Z(problem_count);
|
||||
std::vector<typename HostTensorC::TensorRef> ref_ref_D0(problem_count);
|
||||
std::vector<typename HostTensorC::TensorRef> ref_ref_D1(problem_count);
|
||||
|
||||
for (int i = 0; i < problem_count; ++i) {
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
|
||||
auto problem_size_0 = problem_sizes_0[i];
|
||||
auto problem_size_1 = problem_sizes_1[i];
|
||||
|
||||
host_tensor_A0.at(i) = HostTensorA(problem_size_0.mk());
|
||||
host_tensor_B0.at(i) = HostTensorB(problem_size_0.kn());
|
||||
host_tensor_C0.at(i) = HostTensorC(problem_size_0.mn());
|
||||
if (alpha0 == ElementCompute(0)) //per-channel scale
|
||||
host_tensor_Scale0.at(i) = HostTensorScale(typename HostTensorZ::Layout::TensorCoord{1, problem_size_0.n()});
|
||||
host_tensor_Bias0.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_0.n()});
|
||||
host_tensor_Z.at(i) = HostTensorZ(problem_size_0.mn());
|
||||
host_tensor_ref_D0.at(i) = HostTensorC(problem_size_0.mn());
|
||||
host_tensor_B1.at(i) = HostTensorB(problem_size_1.kn());
|
||||
host_tensor_C1.at(i) = HostTensorC(problem_size_1.mn());
|
||||
host_tensor_Bias1.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_1.n()});
|
||||
host_tensor_D1.at(i) = HostTensorC(problem_size_1.mn());
|
||||
host_tensor_ref_D1.at(i) = HostTensorC(problem_size_1.mn());
|
||||
|
||||
CHECK_TRUE(initialize_tensor(host_tensor_A0.at(i).host_view(), init_A, seed + 2019));
|
||||
CHECK_TRUE(initialize_tensor(host_tensor_B0.at(i).host_view(), init_B, seed + 2018));
|
||||
CHECK_TRUE(initialize_tensor(host_tensor_C0.at(i).host_view(), init_C, seed + 2017));
|
||||
if (alpha0 == ElementCompute(0)) //per-channel scale
|
||||
CHECK_TRUE(initialize_tensor(host_tensor_Scale0.at(i).host_view(), init_Scale, seed + 2014));
|
||||
CHECK_TRUE(initialize_tensor(host_tensor_Bias0.at(i).host_view(), init_Bias, seed + 2013));
|
||||
CHECK_TRUE(initialize_tensor(host_tensor_B1.at(i).host_view(), init_B, seed + 2016));
|
||||
CHECK_TRUE(initialize_tensor(host_tensor_C1.at(i).host_view(), init_C, seed + 2015));
|
||||
CHECK_TRUE(initialize_tensor(host_tensor_Bias1.at(i).host_view(), init_Bias, seed + 2012));
|
||||
|
||||
cutlass::reference::host::TensorFill(
|
||||
host_tensor_D1.at(i).host_view());
|
||||
cutlass::reference::host::TensorFill(
|
||||
host_tensor_ref_D0.at(i).host_view());
|
||||
cutlass::reference::host::TensorFill(
|
||||
host_tensor_ref_D1.at(i).host_view());
|
||||
|
||||
host_tensor_A0.at(i).sync_device();
|
||||
host_tensor_B0.at(i).sync_device();
|
||||
host_tensor_C0.at(i).sync_device();
|
||||
if (alpha0 == ElementCompute(0)) //per-channel scale
|
||||
host_tensor_Scale0.at(i).sync_device();
|
||||
host_tensor_Bias0.at(i).sync_device();
|
||||
host_tensor_B1.at(i).sync_device();
|
||||
host_tensor_C1.at(i).sync_device();
|
||||
host_tensor_Bias1.at(i).sync_device();
|
||||
host_tensor_D1.at(i).sync_device();
|
||||
host_tensor_ref_D0.at(i).sync_device();
|
||||
host_tensor_ref_D1.at(i).sync_device();
|
||||
|
||||
ref_A0.at(i) = (host_tensor_A0.at(i).device_ref());
|
||||
ref_B0.at(i) = (host_tensor_B0.at(i).device_ref());
|
||||
ref_C0.at(i) = (host_tensor_C0.at(i).device_ref());
|
||||
if (alpha0 == ElementCompute(0)) //per-channel scale
|
||||
ref_Scale0.at(i) = (host_tensor_Scale0.at(i).device_ref());
|
||||
ref_Bias0.at(i) = (host_tensor_Bias0.at(i).device_ref());
|
||||
ref_B1.at(i) = (host_tensor_B1.at(i).device_ref());
|
||||
ref_C1.at(i) = {host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)};
|
||||
ref_Bias1.at(i) = (host_tensor_Bias1.at(i).device_ref());
|
||||
ref_D1.at(i) = (host_tensor_D1.at(i).device_ref());
|
||||
ref_Z.at(i) = (host_tensor_Z.at(i).device_ref());
|
||||
ref_ref_D0.at(i) = (host_tensor_ref_D0.at(i).device_ref());
|
||||
ref_ref_D1.at(i) = (host_tensor_ref_D1.at(i).device_ref());
|
||||
}
|
||||
|
||||
//
|
||||
// Initialize the GEMM operator
|
||||
//
|
||||
|
||||
cutlass::DeviceAllocation<typename HostTensorA::TensorRef> device_ref_A0(problem_count);
|
||||
device_ref_A0.copy_from_host(ref_A0.data());
|
||||
cutlass::DeviceAllocation<typename HostTensorB::TensorRef> device_ref_B0(problem_count);
|
||||
device_ref_B0.copy_from_host(ref_B0.data());
|
||||
cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_C0(problem_count);
|
||||
device_ref_C0.copy_from_host(ref_C0.data());
|
||||
cutlass::DeviceAllocation<typename HostTensorScale::TensorRef> device_ref_Scale0(problem_count);
|
||||
device_ref_Scale0.copy_from_host(ref_Scale0.data());
|
||||
cutlass::DeviceAllocation<typename HostTensorScale::TensorRef> device_ref_Bias0(problem_count);
|
||||
device_ref_Bias0.copy_from_host(ref_Bias0.data());
|
||||
cutlass::DeviceAllocation<typename HostTensorB::TensorRef> device_ref_B1(problem_count);
|
||||
device_ref_B1.copy_from_host(ref_B1.data());
|
||||
cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_C1(problem_count);
|
||||
device_ref_C1.copy_from_host(ref_C1.data());
|
||||
cutlass::DeviceAllocation<typename HostTensorBias::TensorRef> device_ref_Bias1(problem_count);
|
||||
device_ref_Bias1.copy_from_host(ref_Bias1.data());
|
||||
cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_D1(problem_count);
|
||||
device_ref_D1.copy_from_host(ref_D1.data());
|
||||
|
||||
cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> device_problem_sizes_0(problem_count);
|
||||
device_problem_sizes_0.copy_from_host(problem_sizes_0.data());
|
||||
cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> device_problem_sizes_1(problem_count);
|
||||
device_problem_sizes_1.copy_from_host(problem_sizes_1.data());
|
||||
|
||||
B2bGemm b2b_gemm_op;
|
||||
|
||||
int threadblock_count = B2bGemm::sufficient(problem_sizes_1.data(), problem_count);
|
||||
if (!threadblock_count) {
|
||||
std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped GEMM kernel." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
typename B2bGemm::Arguments arguments{
|
||||
problem_count,
|
||||
device_problem_sizes_0.get(),
|
||||
device_problem_sizes_1.get(),
|
||||
device_ref_A0.get(),
|
||||
device_ref_B0.get(),
|
||||
device_ref_C0.get(),
|
||||
device_ref_Scale0.get(),
|
||||
device_ref_Bias0.get(),
|
||||
device_ref_B1.get(),
|
||||
device_ref_C1.get(),
|
||||
device_ref_D1.get(),
|
||||
{alpha0, beta0},
|
||||
{alpha1, beta1},
|
||||
threadblock_count
|
||||
};
|
||||
|
||||
cutlass::Status status = b2b_gemm_op.can_implement(arguments);
|
||||
|
||||
if(status != cutlass::Status::kSuccess) {
|
||||
std::cout << "Problem sizes not supported.\n"
|
||||
<< "Requirments:\n"
|
||||
<< " problem_size_0.M = problem_size_1.M\n"
|
||||
<< " problem_size_0.N = problem_size_1.K\n"
|
||||
<< " ThreadblockShape0::kN = problem_size_0.N\n"
|
||||
<< " ThreadblockShape1::kN = problem_size_1.N" << std::endl;
|
||||
}
|
||||
|
||||
status = b2b_gemm_op.initialize(arguments);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
for(int i = 0; i < warm_ups; i++) {
|
||||
status = b2b_gemm_op();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
//
|
||||
// Run the GEMM
|
||||
//
|
||||
|
||||
cudaEvent_t start, stop;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventCreate(&stop);
|
||||
|
||||
cudaEventRecord(start);
|
||||
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = b2b_gemm_op();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
cudaEventRecord(stop);
|
||||
cudaDeviceSynchronize();
|
||||
float gemmTime;
|
||||
cudaEventElapsedTime(&gemmTime, start, stop);
|
||||
std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
|
||||
|
||||
for (int i = 0; i < problem_count; ++i) {
|
||||
host_tensor_D1.at(i).sync_host();
|
||||
|
||||
//
|
||||
// Verify
|
||||
//
|
||||
|
||||
cutlass::reference::device::Gemm<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
ElementAccumulator, typename B2bGemm::LayoutC,
|
||||
ElementAccumulator, ElementAccumulator>
|
||||
reference_gemm_0;
|
||||
|
||||
cutlass::reference::device::Gemm<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
|
||||
ElementAccumulator>
|
||||
reference_gemm_1;
|
||||
|
||||
auto problem_size_0 = problem_sizes_0[i];
|
||||
auto problem_size_1 = problem_sizes_1[i];
|
||||
|
||||
reference_gemm_0(
|
||||
problem_size_0,
|
||||
ElementAccumulator(1), //intermediate alpha=1
|
||||
ref_A0.at(i),
|
||||
ref_B0.at(i),
|
||||
ElementAccumulator(0), //beta = 0
|
||||
ref_Z.at(i),
|
||||
ref_Z.at(i),
|
||||
ElementAccumulator(0)
|
||||
);
|
||||
|
||||
cutlass::reference::device::TensorScaleBiasGemm<
|
||||
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
|
||||
ElementCompute, typename B2bGemm::LayoutC
|
||||
> (
|
||||
problem_size_0,
|
||||
ref_Z.at(i),
|
||||
ref_ref_D0.at(i),
|
||||
alpha0,
|
||||
ref_Scale0.at(i),
|
||||
ref_Bias0.at(i)
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(host_tensor_ref_D0.at(i).device_view());
|
||||
}
|
||||
|
||||
reference_gemm_1(
|
||||
problem_size_1,
|
||||
alpha1,
|
||||
ref_ref_D0.at(i),
|
||||
ref_B1.at(i),
|
||||
beta1,
|
||||
{host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)},
|
||||
ref_ref_D1.at(i)
|
||||
);
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(host_tensor_ref_D1.at(i).device_view());
|
||||
}
|
||||
cudaDeviceSynchronize();
|
||||
host_tensor_ref_D0.at(i).sync_host();
|
||||
host_tensor_ref_D1.at(i).sync_host();
|
||||
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D0.at(i).host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_D1.at(i).host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D1.at(i).host_view()), 0);
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
host_tensor_ref_D1.at(i).host_view(),
|
||||
host_tensor_D1.at(i).host_view());
|
||||
|
||||
CHECK_TRUE(passed);
|
||||
if (!passed)
|
||||
{
|
||||
|
||||
std::stringstream fname;
|
||||
|
||||
fname << "error_B2bGemm_device_fused.txt";
|
||||
std::cerr << "Check failed for GEMM " << i << " in the group." << std::endl;
|
||||
std::cerr << "Dumping results in " << fname.str() << "\n";
|
||||
|
||||
std::ofstream file(fname.str());
|
||||
|
||||
file
|
||||
<< "GEMM " << i << " in group\n"
|
||||
<< "A0 =\n" << host_tensor_A0.at(i).host_view()
|
||||
<< "\nB0 =\n" << host_tensor_B0.at(i).host_view()
|
||||
<< "\nC0 =\n" << host_tensor_C0.at(i).host_view()
|
||||
<< "\nScale0:\n" << host_tensor_Scale0.at(i).host_view() << "\n"
|
||||
<< "\nBias0:\n" << host_tensor_Bias0.at(i).host_view() << "\n"
|
||||
<< "\nB1 =\n" << host_tensor_B1.at(i).host_view()
|
||||
<< "\nC1 =\n" << host_tensor_C1.at(i).host_view()
|
||||
<< "\nBias1:\n" << host_tensor_Bias1.at(i).host_view() << "\n"
|
||||
<< "\n\nReference =\n" << host_tensor_ref_D1.at(i).host_view()
|
||||
<< "\nComputed =\n" << host_tensor_D1.at(i).host_view();
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -43,6 +43,7 @@
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/host_reorder.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/gemm_complex.h"
|
||||
#include "cutlass/util/reference/device/tensor_relu.h"
|
||||
|
||||
#include "reference/device/tensor_scale_bias.h"
|
||||
@ -76,9 +77,9 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
//
|
||||
|
||||
B2bInterleavedNonFusedGemmRun(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
):
|
||||
@ -87,7 +88,7 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
@ -95,7 +96,7 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, 2, -2, 0);
|
||||
}
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
@ -128,73 +129,72 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
|
||||
/// Executes one test
|
||||
bool run(
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementA,
|
||||
typename Gemm0::ElementA,
|
||||
typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementB,
|
||||
typename Gemm0::ElementB,
|
||||
typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementB,
|
||||
typename Gemm0::ElementB,
|
||||
typename Gemm0::LayoutB> tensor_B0_reordered(problem_size_0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementB,
|
||||
typename Gemm1::ElementB,
|
||||
typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementB,
|
||||
typename Gemm1::ElementB,
|
||||
typename Gemm1::LayoutB> tensor_B1_reordered(problem_size_1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
|
||||
|
||||
|
||||
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
|
||||
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
|
||||
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
|
||||
@ -285,13 +285,13 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = gemm_op_0();
|
||||
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
cudaEventRecord(stop1);
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = gemm_op_1();
|
||||
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
@ -327,36 +327,36 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
|
||||
reference_gemm_0(
|
||||
problem_size_0,
|
||||
alpha0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
beta0,
|
||||
alpha0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
beta0,
|
||||
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
|
||||
reference_D0.device_ref()
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
}
|
||||
|
||||
reference_gemm_1(
|
||||
problem_size_1,
|
||||
alpha1,
|
||||
reference_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
alpha1,
|
||||
reference_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
beta1,
|
||||
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
|
||||
reference_D1.device_ref()
|
||||
);
|
||||
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
}
|
||||
|
||||
// Wait for kernels to finish
|
||||
cudaDeviceSynchronize();
|
||||
reference_D0.sync_host();
|
||||
reference_D1.sync_host();
|
||||
reference_D0.sync_host();
|
||||
reference_D1.sync_host();
|
||||
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
|
||||
@ -364,7 +364,7 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
reference_D1.host_view(),
|
||||
reference_D1.host_view(),
|
||||
tensor_D1.host_view());
|
||||
|
||||
CHECK_TRUE(passed);
|
||||
@ -377,7 +377,7 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
|
||||
std::ofstream file(fname.str());
|
||||
|
||||
file
|
||||
file
|
||||
<< "A0 =\n" << tensor_A0.host_view()
|
||||
<< "\nB0 =\n" << tensor_B0.host_view()
|
||||
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
|
||||
@ -416,9 +416,9 @@ struct B2bInterleavedFusedGemmRun
|
||||
//
|
||||
|
||||
B2bInterleavedFusedGemmRun(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
@ -429,7 +429,7 @@ struct B2bInterleavedFusedGemmRun
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
@ -437,11 +437,11 @@ struct B2bInterleavedFusedGemmRun
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, 2, -2, 0);
|
||||
}
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
@ -470,78 +470,99 @@ struct B2bInterleavedFusedGemmRun
|
||||
|
||||
/// Executes one test
|
||||
bool run(
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
|
||||
// batch_count is used as split-k when mode is kGemm according
|
||||
// to the GemmUniversal interface
|
||||
|
||||
int batch_count = 1,
|
||||
|
||||
int64_t batch_stride_A0 = 0,
|
||||
int64_t batch_stride_B0 = 0,
|
||||
int64_t batch_stride_C0 = 0,
|
||||
int64_t batch_stride_B1 = 0,
|
||||
int64_t batch_stride_C1 = 0,
|
||||
int64_t batch_stride_D1 = 0,
|
||||
int64_t batch_stride_Bias0 = 0,
|
||||
int64_t batch_stride_Scale0 = 0,
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementA,
|
||||
typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk());
|
||||
cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
|
||||
cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
|
||||
cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k());
|
||||
cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k());
|
||||
cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn());
|
||||
typename B2bGemm::ElementA,
|
||||
typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B0_reordered(problem_size_0.kn());
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B0_reordered(CoordB0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Scale0;
|
||||
|
||||
if(alpha0 == ElementCompute(0)) //per-channel scale
|
||||
tensor_Scale0.resize({1, problem_size_0.n()});
|
||||
tensor_Scale0.resize({1, batch_count * problem_size_0.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()});
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
ElementAccumulator,
|
||||
typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn());
|
||||
ElementAccumulator,
|
||||
typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D0(CoordC0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn());
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B1_reordered(problem_size_1.kn());
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B1_reordered(CoordB1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()});
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D1(CoordC1.mn());
|
||||
|
||||
|
||||
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
|
||||
@ -556,9 +577,9 @@ struct B2bInterleavedFusedGemmRun
|
||||
|
||||
//Reorder B0
|
||||
cutlass::reorder_column<16>(
|
||||
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0);
|
||||
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), CoordB0);
|
||||
cutlass::reorder_column<InterleavedK_>(
|
||||
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1);
|
||||
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), CoordB1);
|
||||
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_D1.host_view());
|
||||
@ -581,12 +602,14 @@ struct B2bInterleavedFusedGemmRun
|
||||
tensor_D1.sync_device();
|
||||
reference_D0.sync_device();
|
||||
reference_D1.sync_device();
|
||||
// tensor_Bias0_batched.sync_device();
|
||||
|
||||
//
|
||||
// Initialize the GEMM operator
|
||||
//
|
||||
|
||||
typename B2bGemm::Arguments arguments{
|
||||
mode,
|
||||
problem_size_0,
|
||||
problem_size_1,
|
||||
tensor_A0.device_ref(),
|
||||
@ -597,8 +620,16 @@ struct B2bInterleavedFusedGemmRun
|
||||
tensor_B1_reordered.device_ref(),
|
||||
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
|
||||
tensor_D1.device_ref(),
|
||||
batch_stride_A0,
|
||||
batch_stride_B0,
|
||||
batch_stride_B1,
|
||||
batch_stride_C1,
|
||||
batch_stride_D1,
|
||||
batch_stride_Bias0,
|
||||
batch_stride_Scale0,
|
||||
{alpha0, beta0},
|
||||
{alpha1, beta1},
|
||||
batch_count,
|
||||
};
|
||||
|
||||
B2bGemm b2b_gemm_op;
|
||||
@ -651,32 +682,30 @@ struct B2bInterleavedFusedGemmRun
|
||||
// Verify
|
||||
//
|
||||
|
||||
cutlass::reference::device::Gemm<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
ElementAccumulator, typename B2bGemm::LayoutC,
|
||||
ElementAccumulator, ElementAccumulator>
|
||||
reference_gemm_0;
|
||||
|
||||
cutlass::reference::device::Gemm<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
|
||||
ElementAccumulator, typename B2bGemm::Operator>
|
||||
reference_gemm_1;
|
||||
|
||||
reference_gemm_0(
|
||||
cutlass::reference::device::GemmComplex<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
ElementAccumulator, typename B2bGemm::LayoutC,
|
||||
ElementAccumulator, ElementAccumulator
|
||||
>(
|
||||
problem_size_0,
|
||||
ElementAccumulator(1), //intermediate alpha=1
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
tensor_A0.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
tensor_B0.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
ElementAccumulator(0), //beta = 0
|
||||
reference_Z0.device_ref(),
|
||||
reference_Z0.device_ref(),
|
||||
ElementAccumulator(0)
|
||||
ElementAccumulator(0),
|
||||
int(batch_count),
|
||||
batch_stride_A0,
|
||||
batch_stride_B0,
|
||||
batch_stride_C0,
|
||||
batch_stride_C0
|
||||
);
|
||||
|
||||
cutlass::reference::device::TensorScaleBiasGemm<
|
||||
cutlass::reference::device::TensorScaleBiasGemmBatched<
|
||||
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
|
||||
ElementCompute, typename B2bGemm::LayoutScaleBias
|
||||
> (
|
||||
@ -685,25 +714,45 @@ struct B2bInterleavedFusedGemmRun
|
||||
reference_D0.device_ref(),
|
||||
alpha0,
|
||||
tensor_Scale0.device_ref(),
|
||||
tensor_Bias0.device_ref()
|
||||
tensor_Bias0.device_ref(),
|
||||
int(batch_count),
|
||||
batch_stride_C0,
|
||||
batch_stride_C0,
|
||||
batch_stride_Scale0,
|
||||
batch_stride_Bias0
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
}
|
||||
|
||||
reference_gemm_1(
|
||||
cutlass::reference::device::GemmComplex<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
|
||||
ElementCompute, ElementAccumulator
|
||||
>(
|
||||
problem_size_1,
|
||||
alpha1,
|
||||
reference_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
beta1,
|
||||
alpha1, //intermediate alpha=1
|
||||
reference_D0.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
tensor_B1.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
beta1, //beta = 0
|
||||
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
|
||||
reference_D1.device_ref()
|
||||
reference_D1.device_ref(),
|
||||
ElementAccumulator(0),
|
||||
int(batch_count),
|
||||
batch_stride_C0,
|
||||
batch_stride_B1,
|
||||
batch_stride_C1,
|
||||
batch_stride_D1
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
}
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
reference_D0.sync_host();
|
||||
reference_D1.sync_host();
|
||||
@ -713,7 +762,7 @@ struct B2bInterleavedFusedGemmRun
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
reference_D1.host_view(),
|
||||
reference_D1.host_view(),
|
||||
tensor_D1.host_view());
|
||||
|
||||
CHECK_TRUE(passed);
|
||||
@ -727,7 +776,7 @@ struct B2bInterleavedFusedGemmRun
|
||||
|
||||
std::ofstream file(fname.str());
|
||||
|
||||
file
|
||||
file
|
||||
<< "A0 =\n" << tensor_A0.host_view()
|
||||
<< "\nB0 =\n" << tensor_B0.host_view()
|
||||
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -119,8 +119,6 @@ template <
|
||||
int AlignmentB =
|
||||
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||||
ElementC_, ElementAccumulator_>::kAlignmentB,
|
||||
/// If true, kernel supports split-K with serial reduction
|
||||
bool SplitKSerial = false,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
@ -154,7 +152,6 @@ class B2bGemm {
|
||||
static int const kAlignmentA = AlignmentA;
|
||||
static int const kAlignmentB = AlignmentB;
|
||||
static int const kAlignmentC = EpilogueOutputOp1::kCount;
|
||||
static bool const kSplitKSerial = SplitKSerial;
|
||||
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
||||
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
||||
|
||||
@ -184,77 +181,11 @@ class B2bGemm {
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
kStages,
|
||||
kSplitKSerial,
|
||||
Operator,
|
||||
SmemAccumulator
|
||||
>::B2bGemmKernel;
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
GemmCoord problem_size_0;
|
||||
GemmCoord problem_size_1;
|
||||
TensorRef<ElementA const, LayoutA> ref_A0;
|
||||
TensorRef<ElementB const, LayoutB> ref_B0;
|
||||
TensorRef<ElementC const, LayoutC> ref_C0;
|
||||
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0;
|
||||
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0;
|
||||
TensorRef<ElementB const, LayoutB> ref_B1;
|
||||
TensorRef<ElementC const, LayoutC> ref_C1;
|
||||
TensorRef<ElementC, LayoutC> ref_D1;
|
||||
typename EpilogueOutputOp0::Params epilogue0;
|
||||
typename EpilogueOutputOp1::Params epilogue1;
|
||||
int split_k_slices;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(): problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), split_k_slices(1) {
|
||||
|
||||
}
|
||||
|
||||
/// Constructs an Arguments structure
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
GemmCoord problem_size_0_,
|
||||
GemmCoord problem_size_1_,
|
||||
TensorRef<ElementA const, LayoutA> ref_A0_,
|
||||
TensorRef<ElementB const, LayoutB> ref_B0_,
|
||||
TensorRef<ElementC const, LayoutC> ref_C0_,
|
||||
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0_,
|
||||
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0_,
|
||||
TensorRef<ElementB const, LayoutB> ref_B1_,
|
||||
TensorRef<ElementC const, LayoutC> ref_C1_,
|
||||
TensorRef<ElementC, LayoutC> ref_D1_,
|
||||
typename EpilogueOutputOp0::Params epilogue0_ =
|
||||
typename EpilogueOutputOp0::Params(),
|
||||
typename EpilogueOutputOp1::Params epilogue1_ =
|
||||
typename EpilogueOutputOp1::Params(),
|
||||
int split_k_slices_ = 1
|
||||
):
|
||||
problem_size_0(problem_size_0_),
|
||||
problem_size_1(problem_size_1_),
|
||||
ref_A0(ref_A0_),
|
||||
ref_B0(ref_B0_),
|
||||
ref_C0(ref_C0_),
|
||||
ref_Scale0(ref_Scale0_),
|
||||
ref_Bias0(ref_Bias0_),
|
||||
ref_B1(ref_B1_),
|
||||
ref_C1(ref_C1_),
|
||||
ref_D1(ref_D1_),
|
||||
epilogue0(epilogue0_),
|
||||
epilogue1(epilogue1_),
|
||||
split_k_slices(split_k_slices_) {
|
||||
|
||||
}
|
||||
};
|
||||
using Arguments = typename B2bGemmKernel::Arguments;
|
||||
|
||||
private:
|
||||
|
||||
@ -269,10 +200,6 @@ public:
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status can_implement(Arguments const &args) {
|
||||
|
||||
if (!kSplitKSerial && args.split_k_slices > 1) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
Status status = B2bGemmKernel::can_implement(
|
||||
args.problem_size_0,
|
||||
args.problem_size_1,
|
||||
@ -295,20 +222,14 @@ public:
|
||||
static size_t get_workspace_size(Arguments const &args) {
|
||||
|
||||
size_t bytes = 0;
|
||||
|
||||
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size_0,
|
||||
args.problem_size_0,
|
||||
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
|
||||
args.split_k_slices);
|
||||
|
||||
if (kSplitKSerial && args.split_k_slices > 1) {
|
||||
|
||||
|
||||
bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
|
||||
}
|
||||
args.batch_count);
|
||||
|
||||
return bytes;
|
||||
}
|
||||
@ -320,38 +241,17 @@ public:
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size_0,
|
||||
args.problem_size_0,
|
||||
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
|
||||
args.split_k_slices);
|
||||
args.batch_count);
|
||||
// cutlass::gemm::GemmCoord grid_shape_1 = threadblock_swizzle.get_tiled_shape(
|
||||
// args.problem_size_1,
|
||||
// args.problem_size_1,
|
||||
// {ThreadblockShape1::kM, ThreadblockShape1::kN, ThreadblockShape1::kK},
|
||||
// args.split_k_slices);
|
||||
|
||||
if (kSplitKSerial) {
|
||||
if (args.split_k_slices > 1) {
|
||||
if (!workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
size_t bytes = get_workspace_size(args);
|
||||
|
||||
cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
if (args.split_k_slices > 1) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
}
|
||||
// args.batch_count);
|
||||
|
||||
// Initialize the Params structure
|
||||
params_ = typename B2bGemmKernel::Params{
|
||||
args.mode,
|
||||
args.problem_size_0,
|
||||
args.problem_size_1,
|
||||
grid_shape,
|
||||
@ -363,6 +263,13 @@ public:
|
||||
args.ref_B1.non_const_ref(),
|
||||
args.ref_C1.non_const_ref(),
|
||||
args.ref_D1,
|
||||
args.batch_stride_A0,
|
||||
args.batch_stride_B0,
|
||||
args.batch_stride_B1,
|
||||
args.batch_stride_C1,
|
||||
args.batch_stride_D1,
|
||||
args.batch_stride_Bias0,
|
||||
args.batch_stride_Scale0,
|
||||
args.epilogue0,
|
||||
args.epilogue1,
|
||||
static_cast<int *>(workspace),
|
||||
@ -373,12 +280,6 @@ public:
|
||||
|
||||
/// Lightweight update given a subset of arguments
|
||||
Status update(Arguments const &args, void *workspace = nullptr) {
|
||||
|
||||
if (kSplitKSerial && args.split_k_slices > 1) {
|
||||
if (!workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
}
|
||||
|
||||
params_.ref_A0.reset(args.ref_A0.non_const_ref().data());
|
||||
params_.ref_B0.reset(args.ref_B0.non_const_ref().data());
|
||||
@ -430,12 +331,12 @@ public:
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr,
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) {
|
||||
|
||||
|
||||
Status status = initialize(args, workspace, stream);
|
||||
|
||||
|
||||
if (status == Status::kSuccess) {
|
||||
status = run(stream);
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -220,7 +220,6 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_rf_res() {
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
@ -229,10 +228,6 @@ int main() {
|
||||
};
|
||||
|
||||
return testRun(75, funcs, "conv int8 RF residency");
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -39,7 +39,6 @@
|
||||
#include "device/b2b_implicit_gemm_convolution.h"
|
||||
#include "b2b_interleaved_conv2d_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_0 (
|
||||
@ -219,20 +218,13 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_shmem() {
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_conv2d_fprop_optimized_s8_sm75,
|
||||
&run_fused_conv2d_fprop_optimized_s8_sm75_shmem
|
||||
};
|
||||
|
||||
return testRun(75, funcs, "conv int8 shmem staging");
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -0,0 +1,297 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Example of running grouped back-to-back GEMMs when intermediate results are RF resident
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/base_grouped.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
|
||||
#include "device/b2b_gemm.h"
|
||||
#include "kernel/default_b2b_gemm.h"
|
||||
#include "threadblock/grouped_threadblock_swizzle.h"
|
||||
#include "b2b_grouped_gemm_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
std::vector<cutlass::gemm::GemmCoord> gemm_f16_sm80_problem_sizes_0;
|
||||
std::vector<cutlass::gemm::GemmCoord> gemm_f16_sm80_problem_sizes_1;
|
||||
|
||||
// Constraints:
|
||||
// 1. Warp shape N must equal thread block shape N
|
||||
// 2. Problem size N must equal thread block shape N
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>;
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
bool error;
|
||||
bool reference_check;
|
||||
int alignment = 8;
|
||||
|
||||
std::vector<cutlass::gemm::GemmCoord> problem_sizes0;
|
||||
std::vector<cutlass::gemm::GemmCoord> problem_sizes1;
|
||||
|
||||
int problem_count;
|
||||
bool verbose;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
error(false),
|
||||
reference_check(true),
|
||||
problem_count(15),
|
||||
verbose(false)
|
||||
{ }
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
return;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("problems", problem_count, 15);
|
||||
cmd.get_cmd_line_argument("reference-check", reference_check, true);
|
||||
cmd.get_cmd_line_argument("verbose", verbose, false);
|
||||
|
||||
randomize_problems(cmd);
|
||||
}
|
||||
|
||||
void randomize_problems(cutlass::CommandLine &cmd) {
|
||||
|
||||
//
|
||||
// For now, randomly choose the problem sizes.
|
||||
//
|
||||
|
||||
int cmd_line_m = -1;
|
||||
int cmd_line_k = -1;
|
||||
|
||||
cmd.get_cmd_line_argument("m", cmd_line_m);
|
||||
cmd.get_cmd_line_argument("k", cmd_line_k);
|
||||
|
||||
problem_sizes0.reserve(problem_count);
|
||||
problem_sizes1.reserve(problem_count);
|
||||
|
||||
for (int i = 0; i < problem_count; ++i) {
|
||||
|
||||
int m = cmd_line_m;
|
||||
int k = cmd_line_k;
|
||||
|
||||
if (m < 1) {
|
||||
m = alignment * ((rand() % 256) + 1);
|
||||
}
|
||||
|
||||
if (k < 1) {
|
||||
k = alignment * ((rand() % 256) + 1);
|
||||
}
|
||||
|
||||
cutlass::gemm::GemmCoord problem0(m, ThreadblockShape0::kN, k);
|
||||
cutlass::gemm::GemmCoord problem1(m, ThreadblockShape1::kN, ThreadblockShape0::kN);
|
||||
|
||||
problem_sizes0.push_back(problem0);
|
||||
problem_sizes1.push_back(problem1);
|
||||
}
|
||||
|
||||
if (verbose) {
|
||||
print_problem_sizes();
|
||||
}
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "13_fused_two_gemms_grouped_f16_sm80_rf\n\n"
|
||||
<< " This example runs a grouped back-to-back GEMM kernel. A group of independent back-to-back GEMMs are\n"
|
||||
<< " run in a single kernel. Each indivdual problem in the group is subject to the same constraints that non-grouped\n"
|
||||
<< " back-to-back GEMMs are subject to.s"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --problems=<int> Number of individual GEMM problems (default: --problems=15)\n"
|
||||
<< " --m=<int> Sets the M dimension of both GEMMs for all groups. Otherwise, it is selected randomly\n"
|
||||
<< " --k=<int> Sets the K dimension of the first GEMM for all groups. Otherwise, it is selected randomly\n"
|
||||
<< " --verbose=<bool> If true, prints problem sizes.\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
|
||||
<< "# Runs a grouped B2b GEMM with 10 random problem sizes\n"
|
||||
<< "$ ./examples/13_two_tensor_op_fusion/13_fused_two_gemms_grouped_f16_sm80_rf --groups=10\n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
void print_problem_sizes() {
|
||||
std::cout << std::endl;
|
||||
std::cout << "Executing " << problem_count << " independent back-to-back GEMMs in a group" << std::endl;
|
||||
for (int i = 0; i < problem_count; ++i) {
|
||||
cutlass::gemm::GemmCoord problem0 = problem_sizes0.at(i);
|
||||
cutlass::gemm::GemmCoord problem1 = problem_sizes1.at(i);
|
||||
std::cout << "Problem " << i
|
||||
<< "\t\tGEMM0: " << problem0.m() << 'x' << problem0.n() << 'x' << problem0.k()
|
||||
<< "\t\tGEMM1: " << problem1.m() << 'x' << problem1.n() << 'x' << problem1.k()
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
bool run_fused_grouped_gemm_f16_sm80_rf_res() {
|
||||
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
//Fused kernel has built-in bias, setting beta=0
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
|
||||
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>;
|
||||
|
||||
using GroupedThreadblockSwizzle = cutlass::gemm::threadblock::B2bGemmGroupedThreadblockSwizzle<
|
||||
ThreadblockShape0,
|
||||
cutlass::layout::RowMajor // LayoutC
|
||||
>;
|
||||
|
||||
const int kAlignment = 128 / cutlass::sizeof_bits<ElementOutput>::value;
|
||||
const int kStages = 3;
|
||||
using B2bGemmKernel = cutlass::gemm::kernel::DefaultB2bGemm<
|
||||
cutlass::half_t,
|
||||
cutlass::layout::RowMajor,
|
||||
kAlignment,
|
||||
cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor,
|
||||
kAlignment,
|
||||
cutlass::half_t,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
GroupedThreadblockSwizzle,
|
||||
kStages,
|
||||
cutlass::arch::OpMultiplyAdd
|
||||
>::B2bGemmKernel;
|
||||
|
||||
using B2bGemm = cutlass::gemm::device::BaseGrouped<B2bGemmKernel>;
|
||||
|
||||
B2bFusedGroupedGemmRun<B2bGemm> fusedGemm;
|
||||
|
||||
std::cout << "Running Fused back-to-back FP16 TN Grouped GEMMs with RF residency...\n";
|
||||
bool passed = fusedGemm.run(gemm_f16_sm80_problem_sizes_0, gemm_f16_sm80_problem_sizes_1, alpha0, beta0, alpha1, beta1);
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
//
|
||||
// Parse options
|
||||
//
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.error) {
|
||||
std::cerr << "Aborting execution." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
gemm_f16_sm80_problem_sizes_0 = options.problem_sizes0;
|
||||
gemm_f16_sm80_problem_sizes_1 = options.problem_sizes1;
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_fused_grouped_gemm_f16_sm80_rf_res
|
||||
};
|
||||
|
||||
return testRun(80, funcs, "grouped gemm f16 RF residency");
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -195,7 +195,6 @@ bool run_fused_gemm_s8_rf_res() {
|
||||
return passed;
|
||||
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
@ -204,9 +203,6 @@ int main() {
|
||||
};
|
||||
|
||||
return testRun(75, funcs, "gemm int8 RF residency");
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -43,7 +43,6 @@
|
||||
#include "device/b2b_gemm.h"
|
||||
#include "b2b_interleaved_gemm_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_0(128*640, 64, 576);
|
||||
@ -197,18 +196,13 @@ bool run_fused_gemm_s8_shmem() {
|
||||
return passed;
|
||||
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_gemm_s8,
|
||||
&run_fused_gemm_s8_shmem
|
||||
};
|
||||
|
||||
return testRun(75, funcs, "gemm int8 shmem staing");
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -152,7 +152,7 @@ bool run_fused_gemm_s8_sm80_rf_res() {
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
8 * InstructionShape::kN / 32,
|
||||
@ -161,7 +161,7 @@ bool run_fused_gemm_s8_sm80_rf_res() {
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
@ -194,14 +194,21 @@ bool run_fused_gemm_s8_sm80_rf_res() {
|
||||
SmemAccumulator,
|
||||
16,
|
||||
16,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAddSaturate
|
||||
>;
|
||||
|
||||
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with RF residency...\n";
|
||||
bool passed = fusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
bool passed = fusedGemm.run(
|
||||
gemm_s8_sm80_problem_size_0,
|
||||
gemm_s8_sm80_problem_size_1,
|
||||
alpha0,
|
||||
beta0,
|
||||
alpha1,
|
||||
beta1
|
||||
);
|
||||
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
@ -210,18 +217,123 @@ bool run_fused_gemm_s8_sm80_rf_res() {
|
||||
return passed;
|
||||
}
|
||||
|
||||
bool run_fused_gemm_s8_sm80_rf_res_batch() {
|
||||
|
||||
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(256, 64, 128);
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(256, 128, 64);
|
||||
|
||||
using ElementOutput = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
//Fused kernel has built-in bias, setting beta=0
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
|
||||
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
8 * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>;
|
||||
|
||||
const bool SmemAccumulator = false;
|
||||
|
||||
using B2bGemm = cutlass::gemm::device::B2bGemm<
|
||||
int8_t,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
int8_t,
|
||||
cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementOutput,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
3,
|
||||
SmemAccumulator,
|
||||
16,
|
||||
16,
|
||||
cutlass::arch::OpMultiplyAddSaturate
|
||||
>;
|
||||
|
||||
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
|
||||
|
||||
int batch_count = 2;
|
||||
int64_t batch_stride_A0 = gemm_s8_sm80_problem_size_0.m() * gemm_s8_sm80_problem_size_0.k();
|
||||
int64_t batch_stride_B0 = gemm_s8_sm80_problem_size_1.k() * gemm_s8_sm80_problem_size_1.n();
|
||||
int64_t batch_stride_C0 = gemm_s8_sm80_problem_size_0.m() * gemm_s8_sm80_problem_size_0.n();
|
||||
int64_t batch_stride_B1 = gemm_s8_sm80_problem_size_1.k() * gemm_s8_sm80_problem_size_1.n();
|
||||
int64_t batch_stride_C1 = gemm_s8_sm80_problem_size_1.n();
|
||||
int64_t batch_stride_D1 = gemm_s8_sm80_problem_size_1.m() * gemm_s8_sm80_problem_size_1.n();
|
||||
int64_t batch_stride_Bias0 = gemm_s8_sm80_problem_size_0.n();
|
||||
int64_t batch_stride_Scale0 = 0;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 NT interleaved Batched GEMMs with RF residency...\n";
|
||||
bool passed = fusedGemm.run(
|
||||
gemm_s8_sm80_problem_size_0,
|
||||
gemm_s8_sm80_problem_size_1,
|
||||
alpha0,
|
||||
beta0,
|
||||
alpha1,
|
||||
beta1,
|
||||
cutlass::gemm::GemmUniversalMode::kBatched,
|
||||
batch_count,
|
||||
batch_stride_A0,
|
||||
batch_stride_B0,
|
||||
batch_stride_C0,
|
||||
batch_stride_B1,
|
||||
batch_stride_C1,
|
||||
batch_stride_D1,
|
||||
batch_stride_Bias0,
|
||||
batch_stride_Scale0
|
||||
);
|
||||
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_gemm_s8_sm80,
|
||||
&run_fused_gemm_s8_sm80_rf_res
|
||||
&run_fused_gemm_s8_sm80_rf_res,
|
||||
&run_fused_gemm_s8_sm80_rf_res_batch
|
||||
};
|
||||
|
||||
return testRun(80, funcs, "gemm int8 RF residency");
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -151,7 +151,7 @@ bool run_fused_gemm_s8_sm80_shmem() {
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
8 * InstructionShape::kN / 32,
|
||||
@ -160,7 +160,7 @@ bool run_fused_gemm_s8_sm80_shmem() {
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
@ -168,7 +168,7 @@ bool run_fused_gemm_s8_sm80_shmem() {
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>;
|
||||
|
||||
|
||||
const bool SmemAccumulator = true;
|
||||
|
||||
using B2bGemm = cutlass::gemm::device::B2bGemm<
|
||||
@ -193,7 +193,6 @@ bool run_fused_gemm_s8_sm80_shmem() {
|
||||
SmemAccumulator,
|
||||
16,
|
||||
16,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAddSaturate
|
||||
>;
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -40,19 +40,66 @@
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
|
||||
#include "kernel/b2b_gemm_grouped_problem_visitor.h"
|
||||
#include "threadblock/grouped_threadblock_swizzle.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// Utility struct for returning the type of the problem visitor used by the swizzling function,
|
||||
/// if it is a grouped swizzling function, or a default visitor. This is used only for defining
|
||||
/// the parameters of the problem visitor used in GroupedParams.
|
||||
template <
|
||||
typename B2bMma_,
|
||||
typename ThreadblockSwizzle_,
|
||||
typename Enable = void
|
||||
>
|
||||
struct ProblemVisitorOrDefault;
|
||||
|
||||
/// Return a generic problem visitor for GEMM problems
|
||||
template <
|
||||
typename B2bMma_,
|
||||
typename ThreadblockSwizzle_
|
||||
>
|
||||
struct ProblemVisitorOrDefault<B2bMma_,
|
||||
ThreadblockSwizzle_,
|
||||
typename platform::enable_if<
|
||||
! cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle_>::value
|
||||
>::type> {
|
||||
using value = B2bGemmGroupedProblemVisitor<typename B2bMma_::Shape,
|
||||
GroupScheduleMode::kDeviceOnly,
|
||||
128,
|
||||
128,
|
||||
platform::is_same<typename B2bMma_::LayoutC,
|
||||
cutlass::layout::ColumnMajor>::value>;
|
||||
};
|
||||
|
||||
/// Return the problem visitor specified by the swizzling function
|
||||
template <
|
||||
typename B2bMma_,
|
||||
typename ThreadblockSwizzle_
|
||||
>
|
||||
struct ProblemVisitorOrDefault<B2bMma_,
|
||||
ThreadblockSwizzle_,
|
||||
typename platform::enable_if<
|
||||
cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle_>::value
|
||||
>::type> {
|
||||
using value = typename ThreadblockSwizzle_::ProblemVisitor;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
|
||||
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
|
||||
>
|
||||
struct B2bGemm {
|
||||
|
||||
@ -61,50 +108,225 @@ struct B2bGemm {
|
||||
using OutputOp0 = typename B2bMma::OutputOp;
|
||||
using OutputOp1 = typename Epilogue::OutputOp;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
static bool const kSplitKSerial = SplitKSerial;
|
||||
|
||||
using ElementA0 = typename B2bMma::IteratorA0::Element;
|
||||
using LayoutA0 = typename B2bMma::IteratorA0::Layout;
|
||||
using ElementB0 = typename B2bMma::IteratorB0::Element;
|
||||
using LayoutB0 = typename B2bMma::IteratorB0::Layout;
|
||||
using ElementB1 = typename B2bMma::IteratorB1::Element;
|
||||
using LayoutB1 = typename B2bMma::IteratorB1::Layout;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
|
||||
|
||||
using ScaleBiasData = typename B2bMma::IteratorAccumulatorScaleBias::Element;
|
||||
|
||||
/// Data types needed for higher-level containers. In some cases, a single type must be exposed
|
||||
/// despite the B2b GEMM using two GEMMs under the hood. In such cases, we select the values from
|
||||
/// the second GEMM (other than for ElementA/ElementB)
|
||||
using ElementA = typename B2bMma::IteratorA0::Element;
|
||||
using LayoutA = typename B2bMma::IteratorA0::Layout;
|
||||
using ElementB = typename B2bMma::IteratorB0::Element;
|
||||
using LayoutB = typename B2bMma::IteratorB0::Layout;
|
||||
|
||||
static ComplexTransform const kTransformA = B2bMma::kTransformA;
|
||||
static ComplexTransform const kTransformB = B2bMma::kTransformB;
|
||||
using Operator = typename B2bMma::Operator0;
|
||||
|
||||
using OperatorClass = typename Operator::OperatorClass;
|
||||
using ThreadblockShape = typename B2bMma::Shape0;
|
||||
using WarpShape = typename Operator::Shape;
|
||||
using InstructionShape = typename Operator::InstructionShape;
|
||||
using ArchTag = typename B2bMma::ArchTag;
|
||||
|
||||
static int const kStages = B2bMma::kStages;
|
||||
static int const kAlignmentA = B2bMma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = B2bMma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
using Mma = B2bMma;
|
||||
using EpilogueOutputOp = OutputOp1;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount0 = typename B2bMma::WarpCount0;
|
||||
static int const kThreadCount = 32 * WarpCount0::kCount;
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
cutlass::gemm::GemmCoord problem_size_0;
|
||||
cutlass::gemm::GemmCoord problem_size_1;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int swizzle_log_tile;
|
||||
typename B2bMma::IteratorA0::Params params_A0;
|
||||
typename B2bMma::IteratorA0::TensorRef ref_A0;
|
||||
typename B2bMma::IteratorB0::Params params_B0;
|
||||
typename B2bMma::IteratorB0::TensorRef ref_B0;
|
||||
typename Epilogue::OutputTileIterator::Params params_C0;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C0;
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0;
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0;
|
||||
typename B2bMma::IteratorB1::Params params_B1;
|
||||
typename B2bMma::IteratorB1::TensorRef ref_B1;
|
||||
typename Epilogue::OutputTileIterator::Params params_C1;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C1;
|
||||
typename Epilogue::OutputTileIterator::Params params_D1;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D1;
|
||||
typename OutputOp0::Params output_op_0;
|
||||
typename OutputOp1::Params output_op_1;
|
||||
int *semaphore;
|
||||
int gemm_k_iterations_0;
|
||||
int gemm_k_size_0;
|
||||
int gemm_k_iterations_1;
|
||||
int gemm_k_size_1;
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm;
|
||||
GemmCoord problem_size_0{0,0,0};
|
||||
GemmCoord problem_size_1{0,0,0};
|
||||
typename B2bMma::IteratorA0::TensorRef ref_A0{};
|
||||
typename B2bMma::IteratorB0::TensorRef ref_B0{};
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C0{};
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0{};
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0{};
|
||||
typename B2bMma::IteratorB1::TensorRef ref_B1{};
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C1{};
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D1{};
|
||||
int64_t batch_stride_A0{0};
|
||||
int64_t batch_stride_B0{0};
|
||||
int64_t batch_stride_B1{0};
|
||||
int64_t batch_stride_C1{0};
|
||||
int64_t batch_stride_D1{0};
|
||||
int64_t batch_stride_Bias0{0};
|
||||
int64_t batch_stride_Scale0{0};
|
||||
typename OutputOp0::Params epilogue0 {};
|
||||
typename OutputOp1::Params epilogue1 {};
|
||||
int batch_count{1};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
Arguments() = default;
|
||||
|
||||
/// Constructs an Arguments structure
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
|
||||
gemm_k_iterations_1(0), gemm_k_size_1(0) { }
|
||||
Arguments(
|
||||
GemmUniversalMode mode_,
|
||||
GemmCoord problem_size_0_,
|
||||
GemmCoord problem_size_1_,
|
||||
typename B2bMma::IteratorA0::TensorRef ref_A0_,
|
||||
typename B2bMma::IteratorB0::TensorRef ref_B0_,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C0_,
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0_,
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0_,
|
||||
typename B2bMma::IteratorB1::TensorRef ref_B1_,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C1_,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D1_,
|
||||
int64_t batch_stride_A0_,
|
||||
int64_t batch_stride_B0_,
|
||||
int64_t batch_stride_B1_,
|
||||
int64_t batch_stride_C1_,
|
||||
int64_t batch_stride_D1_,
|
||||
int64_t batch_stride_Bias0_,
|
||||
int64_t batch_stride_Scale0_,
|
||||
typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(),
|
||||
typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(),
|
||||
int batch_count_ = 1
|
||||
):
|
||||
mode(mode_),
|
||||
problem_size_0(problem_size_0_),
|
||||
problem_size_1(problem_size_1_),
|
||||
ref_A0(ref_A0_),
|
||||
ref_B0(ref_B0_),
|
||||
ref_C0(ref_C0_),
|
||||
ref_Scale0(ref_Scale0_),
|
||||
ref_Bias0(ref_Bias0_),
|
||||
ref_B1(ref_B1_),
|
||||
ref_C1(ref_C1_),
|
||||
ref_D1(ref_D1_),
|
||||
batch_stride_A0(batch_stride_A0_),
|
||||
batch_stride_B0(batch_stride_B0_),
|
||||
batch_stride_B1(batch_stride_B1_),
|
||||
batch_stride_C1(batch_stride_C1_),
|
||||
batch_stride_D1(batch_stride_D1_),
|
||||
batch_stride_Bias0(batch_stride_Bias0_),
|
||||
batch_stride_Scale0(batch_stride_Scale0_),
|
||||
epilogue0(epilogue0_),
|
||||
epilogue1(epilogue1_),
|
||||
batch_count(batch_count_) {
|
||||
}
|
||||
};
|
||||
|
||||
// Arguments structure for grouped B2B problems
|
||||
struct GroupedArguments {
|
||||
GemmCoord* problem_size_0;
|
||||
GemmCoord* problem_size_1;
|
||||
typename B2bMma::IteratorA0::TensorRef* ref_A0;
|
||||
typename B2bMma::IteratorB0::TensorRef* ref_B0;
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_C0;
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0;
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0;
|
||||
typename B2bMma::IteratorB1::TensorRef* ref_B1;
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_D1;
|
||||
|
||||
// Epilogue params remain constant across all problmes in the group. Thus,
|
||||
// the parameter here is not a pointer.
|
||||
typename OutputOp0::Params epilogue0;
|
||||
typename OutputOp1::Params epilogue1;
|
||||
|
||||
int problem_count;
|
||||
int threadblock_count;
|
||||
GemmCoord* host_problem_sizes;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
GroupedArguments(
|
||||
int problem_count,
|
||||
GemmCoord* problem_size_0_,
|
||||
GemmCoord* problem_size_1_,
|
||||
typename B2bMma::IteratorA0::TensorRef* ref_A0_,
|
||||
typename B2bMma::IteratorB0::TensorRef* ref_B0_,
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_C0_,
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0_,
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0_,
|
||||
typename B2bMma::IteratorB1::TensorRef* ref_B1_,
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_C1_,
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_D1_,
|
||||
typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(),
|
||||
typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(),
|
||||
int threadblock_count = 0
|
||||
) : problem_size_0(problem_size_0_), problem_size_1(problem_size_1_),
|
||||
ref_A0(ref_A0_), ref_B0(ref_B0_), ref_C0(ref_C0_),
|
||||
ref_Scale0(ref_Scale0_), ref_Bias0(ref_Bias0_), ref_B1(ref_B1_),
|
||||
ref_C1(ref_C1_), ref_D1(ref_D1_), epilogue0(epilogue0_), epilogue1(epilogue1_),
|
||||
problem_count(problem_count),
|
||||
threadblock_count(threadblock_count)
|
||||
{}
|
||||
};
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm;
|
||||
cutlass::gemm::GemmCoord problem_size_0{};
|
||||
cutlass::gemm::GemmCoord problem_size_1{};
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape{};
|
||||
int swizzle_log_tile{0};
|
||||
typename B2bMma::IteratorA0::Params params_A0{};
|
||||
typename B2bMma::IteratorA0::TensorRef ref_A0{};
|
||||
typename B2bMma::IteratorB0::Params params_B0{};
|
||||
typename B2bMma::IteratorB0::TensorRef ref_B0{};
|
||||
typename Epilogue::OutputTileIterator::Params params_C0{};
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C0{};
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0{};
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0{};
|
||||
typename B2bMma::IteratorB1::Params params_B1{};
|
||||
typename B2bMma::IteratorB1::TensorRef ref_B1{};
|
||||
typename Epilogue::OutputTileIterator::Params params_C1{};
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C1{};
|
||||
typename Epilogue::OutputTileIterator::Params params_D1{};
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D1{};
|
||||
typename OutputOp0::Params output_op_0{};
|
||||
typename OutputOp1::Params output_op_1{};
|
||||
int64_t batch_stride_A0{0};
|
||||
int64_t batch_stride_B0{0};
|
||||
int64_t batch_stride_B1{0};
|
||||
int64_t batch_stride_C1{0};
|
||||
int64_t batch_stride_D1{0};
|
||||
int64_t batch_stride_Bias0{0};
|
||||
int64_t batch_stride_Scale0{0};
|
||||
int *semaphore = nullptr;
|
||||
int gemm_k_iterations_0{0};
|
||||
int gemm_k_size_0{0};
|
||||
int gemm_k_iterations_1{0};
|
||||
int gemm_k_size_1{0};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
Params() = default;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
cutlass::gemm::GemmUniversalMode mode,
|
||||
cutlass::gemm::GemmCoord const & problem_size_0,
|
||||
cutlass::gemm::GemmCoord const & problem_size_1,
|
||||
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
||||
@ -116,14 +338,22 @@ struct B2bGemm {
|
||||
typename B2bMma::IteratorB1::TensorRef ref_B1,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C1,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D1,
|
||||
int64_t batch_stride_A0,
|
||||
int64_t batch_stride_B0,
|
||||
int64_t batch_stride_B1,
|
||||
int64_t batch_stride_C1,
|
||||
int64_t batch_stride_D1,
|
||||
int64_t batch_stride_Bias0,
|
||||
int64_t batch_stride_Scale0,
|
||||
typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(),
|
||||
typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(),
|
||||
int *workspace = nullptr
|
||||
):
|
||||
mode(mode),
|
||||
problem_size_0(problem_size_0),
|
||||
problem_size_1(problem_size_1),
|
||||
grid_tiled_shape(grid_tiled_shape),
|
||||
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
|
||||
swizzle_log_tile(ThreadblockSwizzle::get_log_tile(grid_tiled_shape)),
|
||||
params_A0(ref_A0.layout()),
|
||||
ref_A0(ref_A0),
|
||||
params_B0(ref_B0.layout()),
|
||||
@ -138,6 +368,13 @@ struct B2bGemm {
|
||||
ref_C1(ref_C1),
|
||||
params_D1(ref_D1.layout()),
|
||||
ref_D1(ref_D1),
|
||||
batch_stride_A0(batch_stride_A0),
|
||||
batch_stride_B0(batch_stride_B0),
|
||||
batch_stride_B1(batch_stride_B1),
|
||||
batch_stride_C1(batch_stride_C1),
|
||||
batch_stride_D1(batch_stride_D1),
|
||||
batch_stride_Bias0(batch_stride_Bias0),
|
||||
batch_stride_Scale0(batch_stride_Scale0),
|
||||
output_op_0(output_op_0),
|
||||
output_op_1(output_op_1) {
|
||||
|
||||
@ -152,6 +389,81 @@ struct B2bGemm {
|
||||
}
|
||||
};
|
||||
|
||||
struct GroupedParams {
|
||||
cutlass::gemm::GemmCoord* problem_size_0;
|
||||
cutlass::gemm::GemmCoord* problem_size_1;
|
||||
cutlass::gemm::GemmCoord* grid_tiled_shape;
|
||||
typename B2bMma::IteratorA0::TensorRef* ref_A0;
|
||||
typename B2bMma::IteratorB0::TensorRef* ref_B0;
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_C0;
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0;
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0;
|
||||
typename B2bMma::IteratorB1::TensorRef* ref_B1;
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
|
||||
typename Epilogue::OutputTileIterator::TensorRef* ref_D1;
|
||||
|
||||
// Epilogue params remain constant across all problmes in the group. Thus,
|
||||
// the parameter here is not a pointer.
|
||||
typename OutputOp0::Params output_op_0;
|
||||
typename OutputOp1::Params output_op_1;
|
||||
|
||||
using ProblemVisitor = typename detail::ProblemVisitorOrDefault<B2bMma, ThreadblockSwizzle>::value;
|
||||
typename ProblemVisitor::Params problem_visitor;
|
||||
int threadblock_count;
|
||||
int* workspace;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
GroupedParams() {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
GroupedParams(
|
||||
GroupedArguments const &args,
|
||||
void *workspace = nullptr,
|
||||
int tile_count = 0
|
||||
) :
|
||||
problem_size_0(args.problem_size_0), problem_size_1(args.problem_size_1),
|
||||
ref_A0(args.ref_A0), ref_B0(args.ref_B0), ref_C0(args.ref_C0),
|
||||
ref_Scale0(args.ref_Scale0), ref_Bias0(args.ref_Bias0), ref_B1(args.ref_B1), ref_C1(args.ref_C1), ref_D1(args.ref_D1),
|
||||
output_op_0(args.epilogue0), output_op_1(args.epilogue1),
|
||||
problem_visitor(args.problem_size_0, args.problem_size_1, args.problem_count, workspace, tile_count),
|
||||
threadblock_count(args.threadblock_count),
|
||||
workspace(reinterpret_cast<int*>(workspace)) {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void transpose() {
|
||||
// Only row-major outputs are currently supported, so no transpose is performed
|
||||
}
|
||||
|
||||
/// Returns non-grouped paramaters to be used as input to the kernel-level
|
||||
/// operator for the problem indicated by problem_visitor.
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params to_single_params(const ProblemVisitor& problem_visitor) const {
|
||||
GemmCoord problem_size0 = problem_visitor.problem_size0();
|
||||
GemmCoord problem_size1 = problem_visitor.problem_size1();
|
||||
int32_t idx = problem_visitor.problem_index();
|
||||
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size1);
|
||||
|
||||
return Params(
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
problem_size0,
|
||||
problem_size1,
|
||||
grid_shape,
|
||||
ref_A0[idx],
|
||||
ref_B0[idx],
|
||||
ref_C0[idx],
|
||||
ref_Scale0[idx],
|
||||
ref_Bias0[idx],
|
||||
ref_B1[idx],
|
||||
ref_C1[idx],
|
||||
ref_D1[idx],
|
||||
0, 0, 0, 0, 0, 0, 0, // Batched B2B GEMMs within the grouped kernel are currently unsupported
|
||||
output_op_0,
|
||||
output_op_1,
|
||||
workspace
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
|
||||
union SharedStorage {
|
||||
typename B2bMma::B2bMmaSharedStorage main_loop;
|
||||
@ -163,7 +475,7 @@ struct B2bGemm {
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
B2bGemm() { }
|
||||
B2bGemm() { }
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(
|
||||
@ -223,7 +535,7 @@ struct B2bGemm {
|
||||
|
||||
if(problem_size_0.n() > B2bMma::Shape0::kN)
|
||||
return Status::kErrorInvalidProblem;
|
||||
|
||||
|
||||
if(problem_size_1.n() > B2bMma::Shape1::kN)
|
||||
return Status::kErrorInvalidProblem;
|
||||
|
||||
@ -233,9 +545,13 @@ struct B2bGemm {
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
||||
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
run_with_swizzle(params, shared_storage, threadblock_swizzle);
|
||||
}
|
||||
|
||||
/// Executes one GEMM with an externally-provided swizzling function
|
||||
CUTLASS_DEVICE
|
||||
void run_with_swizzle(Params const ¶ms, SharedStorage &shared_storage, ThreadblockSwizzle& threadblock_swizzle) {
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset =
|
||||
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
@ -247,37 +563,64 @@ struct B2bGemm {
|
||||
return;
|
||||
}
|
||||
|
||||
ElementA0 *ptr_A0 = static_cast<ElementA0 *>(params.ref_A0.data());
|
||||
ElementB0 *ptr_B0 = static_cast<ElementB0 *>(params.ref_B0.data());
|
||||
ElementB1 *ptr_B1 = static_cast<ElementB1 *>(params.ref_B1.data());
|
||||
|
||||
ScaleBiasData *ptr_Bias0 = static_cast<ScaleBiasData *>(params.ref_Bias0.data());
|
||||
ScaleBiasData *ptr_Scale0 = static_cast<ScaleBiasData *>(params.ref_Scale0.data());
|
||||
|
||||
int offset_k_0 = 0;
|
||||
int offset_k_1 = 0;
|
||||
|
||||
int problem_size_k_0 = params.problem_size_0.k();
|
||||
int problem_size_k_1 = params.problem_size_1.k();
|
||||
|
||||
if (params.mode == GemmUniversalMode::kGemm) {
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
problem_size_k_0 = min(
|
||||
problem_size_k_0,
|
||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
problem_size_k_1 = min(
|
||||
problem_size_k_1,
|
||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
|
||||
|
||||
offset_k_0 = threadblock_tile_offset.k() * params.gemm_k_size_0;
|
||||
offset_k_1 = threadblock_tile_offset.k() * params.gemm_k_size_1;
|
||||
}
|
||||
|
||||
else if (params.mode == GemmUniversalMode::kBatched) {
|
||||
ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A0;
|
||||
ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0;
|
||||
ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1;
|
||||
ptr_Bias0 += threadblock_tile_offset.k() * params.batch_stride_Bias0;
|
||||
ptr_Scale0 += threadblock_tile_offset.k() * params.batch_stride_Scale0;
|
||||
}
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A0{
|
||||
threadblock_tile_offset.m() * B2bMma::Shape0::kM,
|
||||
threadblock_tile_offset.k() * params.gemm_k_size_0,
|
||||
offset_k_0,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B0{
|
||||
threadblock_tile_offset.k() * params.gemm_k_size_0,
|
||||
offset_k_0,
|
||||
threadblock_tile_offset.n() * B2bMma::Shape0::kN
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B1{
|
||||
threadblock_tile_offset.k() * params.gemm_k_size_1,
|
||||
offset_k_1,
|
||||
threadblock_tile_offset.n() * B2bMma::Shape1::kN
|
||||
};
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
int problem_size_k_0 = min(
|
||||
params.problem_size_0.k(),
|
||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK;
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
int problem_size_k_1 = min(
|
||||
params.problem_size_1.k(),
|
||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
|
||||
// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
|
||||
|
||||
|
||||
// Compute position within threadblock
|
||||
@ -286,34 +629,33 @@ struct B2bGemm {
|
||||
// Construct iterators to A and B operands
|
||||
typename B2bMma::IteratorA0 iterator_A0(
|
||||
params.params_A0,
|
||||
params.ref_A0.data(),
|
||||
ptr_A0,
|
||||
{params.problem_size_0.m(), problem_size_k_0},
|
||||
thread_idx,
|
||||
tb_offset_A0);
|
||||
|
||||
typename B2bMma::IteratorB0 iterator_B0(
|
||||
params.params_B0,
|
||||
params.ref_B0.data(),
|
||||
ptr_B0,
|
||||
{problem_size_k_0, params.problem_size_0.n()},
|
||||
thread_idx,
|
||||
tb_offset_B0);
|
||||
|
||||
typename B2bMma::IteratorB1 iterator_B1(
|
||||
params.params_B1,
|
||||
params.ref_B1.data(),
|
||||
ptr_B1,
|
||||
{problem_size_k_1, params.problem_size_1.n()},
|
||||
thread_idx,
|
||||
tb_offset_B1);
|
||||
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
// Construct iterators to accumulator scale/bias vector
|
||||
typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0(
|
||||
params.ref_Scale0.data(),
|
||||
ptr_Scale0,
|
||||
{1, params.problem_size_0.n()},
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
@ -323,7 +665,7 @@ struct B2bGemm {
|
||||
);
|
||||
|
||||
typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0(
|
||||
params.ref_Bias0.data(),
|
||||
ptr_Bias0,
|
||||
{1, params.problem_size_0.n()},
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
@ -332,14 +674,17 @@ struct B2bGemm {
|
||||
)
|
||||
);
|
||||
|
||||
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
|
||||
OutputOp0 output_op_0(params.output_op_0);
|
||||
|
||||
if (cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle>::value) {
|
||||
// Wait for all threads to finish their epilogue phases from the previous tile.
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx, params.problem_size_0.n());
|
||||
|
||||
@ -349,11 +694,9 @@ struct B2bGemm {
|
||||
src_accum.clear();
|
||||
accumulators.clear();
|
||||
|
||||
if (!kSplitKSerial || gemm_k_iterations_0 > 0) {
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
|
||||
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
|
||||
}
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
|
||||
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
@ -376,23 +719,32 @@ struct B2bGemm {
|
||||
|
||||
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
||||
|
||||
ElementC *ptr_C1 = static_cast<ElementC *>(params.ref_C1.data());
|
||||
ElementC *ptr_D1 = static_cast<ElementC *>(params.ref_D1.data());
|
||||
|
||||
// Construct the semaphore.
|
||||
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
|
||||
|
||||
// If performing a reduction via split-K, fetch the initial synchronization
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// Fetch the synchronization lock initially but do not block.
|
||||
semaphore.fetch();
|
||||
if (params.mode == GemmUniversalMode::kGemm) {
|
||||
// If performing a reduction via split-K, fetch the initial synchronization
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
if (params.grid_tiled_shape.k() > 1) {
|
||||
// Fetch the synchronization lock initially but do not block.
|
||||
semaphore.fetch();
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
}
|
||||
}
|
||||
else if (params.mode == GemmUniversalMode::kBatched) {
|
||||
ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1;
|
||||
ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D1;
|
||||
}
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_C1(
|
||||
params.params_C1,
|
||||
params.ref_C1.data(),
|
||||
ptr_C1,
|
||||
params.problem_size_1.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
@ -401,21 +753,21 @@ struct B2bGemm {
|
||||
// Tile iterator writing to destination tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_D1(
|
||||
params.params_D1,
|
||||
params.ref_D1.data(),
|
||||
ptr_D1,
|
||||
params.problem_size_1.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue,
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
shared_storage.epilogue,
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
lane_idx);
|
||||
|
||||
// Wait on the semaphore - this latency may have been covered by iterator construction
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
|
||||
if (threadblock_tile_offset.k()) {
|
||||
iterator_C1 = iterator_D1;
|
||||
@ -427,14 +779,14 @@ struct B2bGemm {
|
||||
}
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
|
||||
|
||||
epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
|
||||
|
||||
//
|
||||
// Release the semaphore
|
||||
//
|
||||
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
int lock = 0;
|
||||
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
|
||||
|
||||
@ -457,4 +809,3 @@ struct B2bGemm {
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
|
||||
@ -0,0 +1,157 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Scheduler for grouped B2b GEMMs
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
|
||||
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Visitor class to abstract away the algorithm for iterating over tiles
|
||||
template <typename ThreadblockShape,
|
||||
GroupScheduleMode GroupScheduleMode_,
|
||||
int PrefetchTileCount,
|
||||
int ThreadCount,
|
||||
bool Transposed = false>
|
||||
struct B2bGemmGroupedProblemVisitor : public GroupedProblemVisitor<
|
||||
detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>,
|
||||
ThreadblockShape,
|
||||
GroupScheduleMode_,
|
||||
PrefetchTileCount,
|
||||
ThreadCount> {
|
||||
|
||||
using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>;
|
||||
using Base = GroupedProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
|
||||
using BaseParams = typename Base::Params;
|
||||
using SharedStorage = typename Base::SharedStorage;
|
||||
static bool const kTransposed = Transposed;
|
||||
|
||||
cutlass::gemm::GemmCoord const *problem_sizes0;
|
||||
cutlass::gemm::GemmCoord const *problem_sizes1;
|
||||
|
||||
struct Params {
|
||||
cutlass::gemm::GemmCoord const *problem_sizes0;
|
||||
cutlass::gemm::GemmCoord const *problem_sizes1;
|
||||
int32_t problem_count;
|
||||
void const *workspace;
|
||||
int32_t tile_count;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(): problem_sizes0(nullptr), problem_sizes1(nullptr),
|
||||
problem_count(0), workspace(nullptr), tile_count(0) { }
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
cutlass::gemm::GemmCoord const *problem_sizes0,
|
||||
cutlass::gemm::GemmCoord const *problem_sizes1,
|
||||
int32_t problem_count,
|
||||
void const *workspace = nullptr,
|
||||
int32_t tile_count = 0
|
||||
):
|
||||
problem_sizes0(problem_sizes0),
|
||||
problem_sizes1(problem_sizes1),
|
||||
problem_count(problem_count),
|
||||
workspace(workspace),
|
||||
tile_count(tile_count)
|
||||
{}
|
||||
|
||||
/// Convert the B2b-GEMM-specific parameters to those used by the base class
|
||||
CUTLASS_HOST_DEVICE
|
||||
BaseParams to_base() const {
|
||||
return BaseParams(// Set problem_sizes as problem_sizes0 because these determine
|
||||
// shape of the grid used in the non-grouped B2b GEMM
|
||||
problem_sizes0,
|
||||
problem_count,
|
||||
workspace,
|
||||
tile_count);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_DEVICE
|
||||
B2bGemmGroupedProblemVisitor(
|
||||
Params const ¶ms_,
|
||||
SharedStorage &shared_storage_,
|
||||
int32_t block_idx
|
||||
): Base (
|
||||
params_.to_base(),
|
||||
shared_storage_, block_idx),
|
||||
problem_sizes0(params_.problem_sizes0),
|
||||
problem_sizes1(params_.problem_sizes1)
|
||||
{}
|
||||
|
||||
/// Returns the problem size 0 for the current problem
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::gemm::GemmCoord problem_size0() const {
|
||||
GemmCoord problem = problem_sizes0[this->problem_idx];
|
||||
ProblemSizeHelper::possibly_transpose_problem(problem);
|
||||
return problem;
|
||||
}
|
||||
|
||||
/// Returns the problem size 1 for the current problem
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::gemm::GemmCoord problem_size1() const {
|
||||
GemmCoord problem = problem_sizes1[this->problem_idx];
|
||||
ProblemSizeHelper::possibly_transpose_problem(problem);
|
||||
return problem;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -30,10 +30,10 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
\brief
|
||||
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
|
||||
the appropriate threadblock-scoped epilogue.
|
||||
|
||||
|
||||
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
|
||||
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
|
||||
specializations here choose 'device::GemmTransposed' to implement this functionality.
|
||||
@ -63,7 +63,9 @@
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
||||
|
||||
#include "kernel/b2b_gemm.h"
|
||||
#include "kernel/grouped.h"
|
||||
#include "threadblock/default_b2b_mma.h"
|
||||
#include "threadblock/grouped_threadblock_swizzle.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -73,6 +75,9 @@ namespace kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
using IsGroupedSwizzle = cutlass::gemm::threadblock::detail::IsGroupedSwizzle<T>;
|
||||
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA_,
|
||||
@ -114,12 +119,12 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// If true, kernel is configured to support serial reduction in the epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Stage accumulator in shared memory
|
||||
bool SmemAccumulator = false
|
||||
bool SmemAccumulator = false,
|
||||
/// Whether or not the operation is grouped
|
||||
typename Enable = void
|
||||
>
|
||||
struct DefaultB2bGemm;
|
||||
|
||||
@ -161,17 +166,77 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
|
||||
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
|
||||
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1, InstructionShape,
|
||||
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial,
|
||||
Operator> {
|
||||
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
|
||||
Operator, false, typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue
|
||||
using Epilogue =
|
||||
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
/// Partial specialization for Ampere Architecture with grouped operation
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp0,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
|
||||
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
|
||||
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1, InstructionShape,
|
||||
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
|
||||
Operator, false, typename platform::enable_if<IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
@ -188,7 +253,9 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
|
||||
EpilogueOutputOp1::kCount>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using UnderlyingB2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
|
||||
using B2bGemmKernel = kernel::GroupedKernel<UnderlyingB2bGemmKernel>;
|
||||
};
|
||||
|
||||
|
||||
@ -228,8 +295,6 @@ template <
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// If true, kernel is configured to support serial reduction in the epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator
|
||||
>
|
||||
@ -249,8 +314,9 @@ struct DefaultB2bGemm<
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
SplitKSerial,
|
||||
Operator
|
||||
Operator,
|
||||
false,
|
||||
typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type
|
||||
> {
|
||||
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
@ -274,7 +340,7 @@ struct DefaultB2bGemm<
|
||||
Operator,
|
||||
EpilogueOutputOp0
|
||||
>::ThreadblockB2bMma;
|
||||
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue
|
||||
@ -287,7 +353,7 @@ struct DefaultB2bGemm<
|
||||
>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
|
||||
@ -323,20 +389,17 @@ template <
|
||||
int Stages,
|
||||
/// Number of Interleaved k
|
||||
int InterleavedK,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<
|
||||
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
|
||||
arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, Stages,
|
||||
SplitKSerial, Operator> {
|
||||
Operator, false, typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
@ -360,7 +423,7 @@ struct DefaultB2bGemm<
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -396,19 +459,17 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of Interleaved k
|
||||
int InterleavedK,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
kAlignmentA, ElementB,
|
||||
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
int32_t, arch::OpClassTensorOp, arch::Sm75,
|
||||
int32_t, arch::OpClassTensorOp, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, 2, SplitKSerial, Operator> {
|
||||
ThreadblockSwizzle, 2, Operator, false,
|
||||
typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
@ -418,7 +479,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC,
|
||||
arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1,
|
||||
arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp0, true>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
@ -430,7 +491,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -30,10 +30,10 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
\brief
|
||||
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
|
||||
the appropriate threadblock-scoped epilogue.
|
||||
|
||||
|
||||
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
|
||||
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
|
||||
specializations here choose 'device::GemmTransposed' to implement this functionality.
|
||||
@ -112,22 +112,19 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
|
||||
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
|
||||
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1, InstructionShape,
|
||||
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial,
|
||||
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
|
||||
Operator, true> {
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, Stages, Operator, EpilogueOutputOp0, false, true>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
@ -139,10 +136,9 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
|
||||
EpilogueOutputOp1::kCount>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for Turing Architecture
|
||||
@ -179,8 +175,6 @@ template <
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// If true, kernel is configured to support serial reduction in the epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator
|
||||
>
|
||||
@ -200,7 +194,6 @@ struct DefaultB2bGemm<
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
SplitKSerial,
|
||||
Operator,
|
||||
true
|
||||
> {
|
||||
@ -228,7 +221,7 @@ struct DefaultB2bGemm<
|
||||
false,
|
||||
true
|
||||
>::ThreadblockB2bMma;
|
||||
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue
|
||||
@ -241,7 +234,7 @@ struct DefaultB2bGemm<
|
||||
>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
|
||||
@ -277,20 +270,17 @@ template <
|
||||
int Stages,
|
||||
/// Number of Interleaved k
|
||||
int InterleavedK,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<
|
||||
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
|
||||
arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, Stages,
|
||||
SplitKSerial, Operator, true> {
|
||||
Operator, true> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
@ -314,7 +304,7 @@ struct DefaultB2bGemm<
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -350,19 +340,16 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of Interleaved k
|
||||
int InterleavedK,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
kAlignmentA, ElementB,
|
||||
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
int32_t, arch::OpClassTensorOp, arch::Sm75,
|
||||
int32_t, arch::OpClassTensorOp, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, 2, SplitKSerial, Operator, true> {
|
||||
ThreadblockSwizzle, 2, Operator, true> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
@ -371,9 +358,9 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, 2, Operator, EpilogueOutputOp0, true, true>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
@ -385,7 +372,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
168
examples/13_two_tensor_op_fusion/kernel/grouped.h
Normal file
168
examples/13_two_tensor_op_fusion/kernel/grouped.h
Normal file
@ -0,0 +1,168 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief High-level interface for running a grouped version of a CUTLASS kernel
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
|
||||
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// High-level interface for running a grouped version of a CUTLASS kernel
|
||||
template <
|
||||
typename BaseKernel_ ///! Kernel-scoped matrix multiply-accumulate
|
||||
>
|
||||
struct GroupedKernel {
|
||||
public:
|
||||
|
||||
using BaseKernel = BaseKernel_;
|
||||
using Epilogue = typename BaseKernel::Epilogue;
|
||||
|
||||
/// Types that need to be exported to work properly with device::BaseGrouped
|
||||
using ElementA = typename BaseKernel::ElementA;
|
||||
using LayoutA = typename BaseKernel::LayoutA;
|
||||
using TensorRefA = TensorRef<ElementA const, LayoutA>;
|
||||
static ComplexTransform const kTransformA = BaseKernel::kTransformA;
|
||||
static int const kAlignmentA = BaseKernel::kAlignmentA;
|
||||
|
||||
using ElementB = typename BaseKernel::ElementB;
|
||||
using LayoutB = typename BaseKernel::LayoutB;
|
||||
using TensorRefB = TensorRef<ElementB const, LayoutB>;
|
||||
static ComplexTransform const kTransformB = BaseKernel::kTransformB;
|
||||
static int const kAlignmentB = BaseKernel::kAlignmentB;
|
||||
|
||||
using ElementC = typename BaseKernel::ElementC;
|
||||
using LayoutC = typename BaseKernel::LayoutC;
|
||||
using TensorRefC = TensorRef<ElementC const, LayoutC>;
|
||||
using TensorRefD = TensorRef<ElementC, LayoutC>;
|
||||
static int const kAlignmentC = BaseKernel::kAlignmentC;
|
||||
|
||||
using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC;
|
||||
|
||||
using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp;
|
||||
using ThreadblockSwizzle = typename BaseKernel::ThreadblockSwizzle;
|
||||
|
||||
using Operator = typename BaseKernel::Operator;
|
||||
using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator;
|
||||
|
||||
using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
|
||||
using MathOperator = typename WarpMmaOperator::MathOperator;
|
||||
using OperatorClass = typename WarpMmaOperator::OperatorClass;
|
||||
using ArchTag = typename WarpMmaOperator::ArchTag;
|
||||
using ThreadblockShape = typename BaseKernel::Mma::Shape;
|
||||
using WarpShape = typename BaseKernel::WarpShape;
|
||||
using InstructionShape = typename BaseKernel::InstructionShape;
|
||||
static int const kStages = BaseKernel::Mma::kStages;
|
||||
|
||||
using Mma = typename BaseKernel::Mma;
|
||||
|
||||
using Arguments = typename BaseKernel::GroupedArguments;
|
||||
using Params = typename BaseKernel::GroupedParams;
|
||||
using ProblemVisitor = typename ThreadblockSwizzle::ProblemVisitor;
|
||||
|
||||
static int const kThreadCount = BaseKernel::kThreadCount;
|
||||
|
||||
/// Shared memory storage structure
|
||||
struct SharedStorage {
|
||||
typename BaseKernel::SharedStorage kernel;
|
||||
|
||||
// ProblemVisitor shared storage can't be overlapped with others
|
||||
typename ProblemVisitor::SharedStorage problem_visitor;
|
||||
};
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
GroupedKernel() { }
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static Status can_implement(Arguments const &args) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Executes a kernel-level GEMM in a loop
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params ¶ms, SharedStorage &shared_storage) {
|
||||
|
||||
ThreadblockSwizzle swizzle(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
|
||||
|
||||
if (ProblemVisitor::kTransposed) {
|
||||
params.transpose();
|
||||
}
|
||||
|
||||
BaseKernel mma;
|
||||
|
||||
// Outer 'persistent' loop to iterate over tiles
|
||||
while (swizzle.problem_visitor.next_tile()) {
|
||||
|
||||
typename BaseKernel::Params mma_params = params.to_single_params(swizzle.problem_visitor);
|
||||
mma.run_with_swizzle(mma_params, shared_storage.kernel, swizzle);
|
||||
|
||||
// Next tile
|
||||
swizzle.problem_visitor.advance(gridDim.x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -69,7 +69,7 @@ __global__ void TensorScaleBiasGemm(
|
||||
TensorRefScalar tensor_scale, ///< scale tensor
|
||||
TensorRefScalar tensor_bias ///< bias tensor
|
||||
) {
|
||||
|
||||
|
||||
ConvertOp convert_op;
|
||||
|
||||
MatrixCoord output_coord(
|
||||
@ -89,7 +89,7 @@ __global__ void TensorScaleBiasGemm(
|
||||
|
||||
ScalarType bias = ScalarType(0);
|
||||
|
||||
if(tensor_bias.good())
|
||||
if(tensor_bias.good())
|
||||
bias = tensor_bias.at({0, coord.column()});
|
||||
|
||||
tensor_out.at(coord) = convert_op(
|
||||
@ -99,6 +99,70 @@ __global__ void TensorScaleBiasGemm(
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename TensorRefIn, ///< Input TensorRef Type
|
||||
typename TensorRefOut, ///< Output TensorRef Type
|
||||
typename ScalarType, ///< alpha Type
|
||||
typename TensorRefScalar, ///< Scale/Bias TensorRef Type
|
||||
typename ConvertOp = NumericConverter<typename TensorRefOut::Element, ScalarType>,
|
||||
int kMblock = 4,
|
||||
int kNblock = 4
|
||||
>
|
||||
__global__ void TensorScaleBiasGemmBatched(
|
||||
gemm::GemmCoord problem_size,
|
||||
TensorRefIn tensor_in, ///< input tensor
|
||||
TensorRefOut tensor_out, ///< output tensor
|
||||
ScalarType alpha, ///< alpha
|
||||
TensorRefScalar tensor_scale, ///< scale tensor
|
||||
TensorRefScalar tensor_bias, ///< bias tensor
|
||||
int batch_count = 1,
|
||||
int64_t batch_stride_tensor_in = 0,
|
||||
int64_t batch_stride_tensor_out = 0,
|
||||
int64_t batch_stride_tensor_scale = 0,
|
||||
int64_t batch_stride_tensor_bias = 0
|
||||
) {
|
||||
|
||||
ConvertOp convert_op;
|
||||
int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
|
||||
int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
|
||||
int batch_idx = blockIdx.z;
|
||||
|
||||
tensor_in.add_pointer_offset(batch_idx * batch_stride_tensor_in);
|
||||
tensor_out.add_pointer_offset(batch_idx * batch_stride_tensor_out);
|
||||
tensor_scale.add_pointer_offset(batch_idx * batch_stride_tensor_scale);
|
||||
tensor_bias.add_pointer_offset(batch_idx * batch_stride_tensor_bias);
|
||||
|
||||
for (; batch_idx < batch_count; batch_idx += gridDim.z) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < kNblock; j++) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kMblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
MatrixCoord coord = MatrixCoord(row, col);
|
||||
if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) {
|
||||
|
||||
ScalarType scale = alpha;
|
||||
if(tensor_scale.good())
|
||||
scale = tensor_scale.at({0, coord.column()});
|
||||
|
||||
ScalarType bias = ScalarType(0);
|
||||
|
||||
if(tensor_bias.good())
|
||||
bias = tensor_bias.at({0, coord.column()});
|
||||
|
||||
tensor_out.at(coord) = convert_op(
|
||||
scale * ScalarType(tensor_in.at(coord)) + bias);
|
||||
}
|
||||
}
|
||||
}
|
||||
tensor_in.add_pointer_offset(batch_stride_tensor_in * gridDim.z);
|
||||
tensor_out.add_pointer_offset(batch_stride_tensor_out * gridDim.z);
|
||||
tensor_scale.add_pointer_offset(batch_stride_tensor_scale * gridDim.z);
|
||||
tensor_bias.add_pointer_offset(batch_stride_tensor_bias * gridDim.z);
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename TensorRefIn, ///< Input TensorRef Type
|
||||
typename TensorRefOut, ///< Output TensorRef Type
|
||||
@ -118,7 +182,7 @@ __global__ void TensorScaleBiasConv2d(
|
||||
TensorRefScalar tensor_scale, ///< scale tensor
|
||||
TensorRefScalar tensor_bias ///< bias tensor
|
||||
) {
|
||||
|
||||
|
||||
ConvertOp convert_op;
|
||||
|
||||
int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
|
||||
@ -137,7 +201,7 @@ __global__ void TensorScaleBiasConv2d(
|
||||
int64_t npq = npq_start + m;
|
||||
|
||||
thread_n[m] = int(npq / PQ);
|
||||
|
||||
|
||||
int64_t residual = npq % PQ;
|
||||
thread_p[m] = int(residual / problem_size.Q);
|
||||
thread_q[m] = int(residual % problem_size.Q);
|
||||
@ -155,17 +219,17 @@ __global__ void TensorScaleBiasConv2d(
|
||||
ScalarType scale = alpha;
|
||||
if(tensor_scale.good())
|
||||
scale = tensor_scale.at({0, thread_k});
|
||||
|
||||
|
||||
ScalarType bias = ScalarType(0);
|
||||
if(tensor_bias.good())
|
||||
if(tensor_bias.good())
|
||||
bias = tensor_bias.at({0, thread_k});
|
||||
|
||||
|
||||
tensor_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
|
||||
scale * ScalarType(
|
||||
tensor_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k})
|
||||
) + bias);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -217,6 +281,62 @@ void TensorScaleBiasGemm(
|
||||
);
|
||||
}
|
||||
|
||||
/// Apply scale and bias on a tensor
|
||||
template <
|
||||
typename ElementIn, ///< Input Type
|
||||
typename ElementOut, ///< Output Type
|
||||
typename Layout, ///< Layout of input/output tensor
|
||||
typename ScalarType, ///< alpha Type
|
||||
typename LayoutScaleBias, ///< Layout of scale and bias
|
||||
typename ConvertOp = NumericConverter<ElementOut, ScalarType>
|
||||
>
|
||||
void TensorScaleBiasGemmBatched(
|
||||
gemm::GemmCoord problem_size,
|
||||
TensorRef<ElementIn, Layout> tensor_in, ///< input tensor
|
||||
TensorRef<ElementOut, Layout> tensor_out, ///< output tensor
|
||||
ScalarType alpha, ///< alpha
|
||||
TensorRef<ScalarType, LayoutScaleBias> tensor_scale, ///< scale tensor
|
||||
TensorRef<ScalarType, LayoutScaleBias> tensor_bias, ///< bias tensor
|
||||
int batch_count = 1,
|
||||
int64_t batch_stride_tensor_in = 0,
|
||||
int64_t batch_stride_tensor_out = 0,
|
||||
int64_t batch_stride_tensor_scale = 0,
|
||||
int64_t batch_stride_tensor_bias = 0
|
||||
) {
|
||||
|
||||
int const kMblock = 4;
|
||||
int const kNblock = 4;
|
||||
|
||||
dim3 block(16, 8);
|
||||
dim3 grid(
|
||||
(problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
|
||||
(problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
|
||||
batch_count % std::numeric_limits<uint16_t>::max()
|
||||
);
|
||||
|
||||
kernel::TensorScaleBiasGemmBatched<
|
||||
TensorRef<ElementIn, Layout>,
|
||||
TensorRef<ElementOut, Layout>,
|
||||
ScalarType,
|
||||
TensorRef<ScalarType, LayoutScaleBias>,
|
||||
ConvertOp,
|
||||
kMblock,
|
||||
kNblock
|
||||
><<< grid, block >>> (
|
||||
problem_size,
|
||||
tensor_in,
|
||||
tensor_out,
|
||||
alpha,
|
||||
tensor_scale,
|
||||
tensor_bias,
|
||||
batch_count,
|
||||
batch_stride_tensor_in,
|
||||
batch_stride_tensor_out,
|
||||
batch_stride_tensor_scale,
|
||||
batch_stride_tensor_bias
|
||||
);
|
||||
}
|
||||
|
||||
/// Apply scale and bias on a tensor
|
||||
template <
|
||||
typename ElementIn, ///< Input Type
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -119,8 +119,10 @@ public:
|
||||
using Shape0 = Shape0_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA0 = IteratorA0_;
|
||||
using IteratorA = IteratorA0;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB0 = IteratorB0_;
|
||||
using IteratorB = IteratorB0;
|
||||
///< Policy describing tuning details
|
||||
using Policy0 = Policy0_;
|
||||
|
||||
@ -139,6 +141,10 @@ public:
|
||||
using IteratorB1 = IteratorB1_;
|
||||
///< Policy describing tuning details
|
||||
using Policy1 = Policy1_;
|
||||
|
||||
///< Export Policy0 as the threadblock-level Mma's policy
|
||||
using Policy = Policy0;
|
||||
using Shape = Shape0;
|
||||
|
||||
using SmemIteratorB1 = SmemIteratorB1_;
|
||||
|
||||
@ -188,6 +194,10 @@ public:
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
|
||||
|
||||
/// Complex transform exports needed by higher-level kernels
|
||||
static ComplexTransform const kTransformA = kTransformA0;
|
||||
static ComplexTransform const kTransformB = kTransformB0;
|
||||
|
||||
/// Internal structure exposed for introspection.
|
||||
struct Detail {
|
||||
|
||||
@ -641,6 +651,11 @@ public:
|
||||
|
||||
}
|
||||
|
||||
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
// 2nd Gemm
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
|
||||
@ -871,7 +886,10 @@ public:
|
||||
|
||||
}
|
||||
|
||||
|
||||
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -121,8 +121,10 @@ public:
|
||||
using Shape0 = Shape0_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA0 = IteratorA0_;
|
||||
using IteratorA = IteratorA0;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB0 = IteratorB0_;
|
||||
using IteratorB = IteratorB0;
|
||||
///< Iterates over tiles of the scale and bias vectors in global memory
|
||||
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_;
|
||||
///< Policy describing tuning details
|
||||
@ -141,6 +143,10 @@ public:
|
||||
///< Policy describing tuning details
|
||||
using Policy1 = Policy1_;
|
||||
|
||||
///< Export Policy0 as the threadblock-level Mma's policy
|
||||
using Policy = Policy0;
|
||||
using Shape = Shape0;
|
||||
|
||||
using SmemIteratorB1 = SmemIteratorB1_;
|
||||
using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory
|
||||
|
||||
@ -194,6 +200,10 @@ public:
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
|
||||
|
||||
/// Complex transform exports needed by higher-level kernels
|
||||
static ComplexTransform const kTransformA = kTransformA0;
|
||||
static ComplexTransform const kTransformB = kTransformB0;
|
||||
|
||||
/// Internal structure exposed for introspection.
|
||||
struct Detail {
|
||||
|
||||
@ -664,6 +674,11 @@ public:
|
||||
|
||||
}
|
||||
|
||||
// Insert fence and wait for all outstanding cp.async operations to commit.
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
/// Epilogue for the first Implicit Gemm
|
||||
Epilogue0 epilogue0;
|
||||
|
||||
@ -855,7 +870,10 @@ public:
|
||||
|
||||
}
|
||||
|
||||
|
||||
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user