Compare commits

...

79 Commits

Author SHA1 Message Date
44c704eae8 Doc updates for 3.2.2 2023-10-26 11:07:30 -07:00
6581237a48 fix issue/1138 2023-10-24 12:02:15 -07:00
5cd735c48e Fix Parallel Split-K on Gemm Operation Profiler (#1109)
* Debug and fix for parallel split-k in profiler

* restore debug files and remove prints
2023-09-26 17:28:00 -04:00
67ae8e0603 Change the position of minus sign in line1549 array.h (#1091)
when I use cutlass::epilogue::thread::LinearCombinationSigmoid, I encounter this error:
cutlass/include/cutlass/array.h(1549): error: no operator "-" matches these operands
Moving operator "-" from line 1549 to 1548 resolves this error
2023-09-26 17:26:39 -04:00
14f69bddc8 [fix] fix comparison operator for integer_subbyte (#1090) 2023-09-26 17:26:12 -04:00
90d3b0fb18 CUTLASS 3.2.1 (#1113)
* Updates for 3.2.1 release.

* Minor fix in gemm op profiler for raster order.

* Add scheduler mapping for raster order in the kernels.
2023-09-26 17:24:26 -04:00
e0aaa3c3b3 fix GmmaDescriptor print format string error (#1102) 2023-09-19 23:27:58 -04:00
8783c41851 Replace 0x1f with 0xffffffff in __shfl_sync (#1097)
This fixes compatibility with H100 and resolves #1094
2023-09-18 19:58:19 -04:00
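A minimal CUDA sketch of the pattern this commit adopts (illustrative, not the CUTLASS source): `__shfl_sync` takes a 32-bit member mask as its first argument, so a full-warp shuffle should pass `0xffffffff`; `0x1f` is a width value carried over from the legacy `__shfl` API and is wrong as a mask.

```
// Broadcast lane 0's value across a full warp using the full participation mask.
__global__ void broadcast_lane0(int* out) {
  int value = threadIdx.x;
  int lane0_value = __shfl_sync(0xffffffff, value, /*srcLane=*/0);
  out[threadIdx.x] = lane0_value;
}
```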
6407bcdf0a fix matrix B indices (#1089) 2023-09-12 14:04:18 -04:00
a77b2c9cb8 style(examples): typo (#1080)
* Update ampere_tensorop_conv2dfprop.cu

learning cutlass, PR a typo.

* Update ampere_gemm_operand_reduction_fusion.cu
2023-09-11 10:13:22 -04:00
34bbadd3ff standardize fp8 generator (#1078) 2023-09-07 14:36:33 -04:00
88c0d7c726 make only visible on device (#1071) 2023-09-07 13:00:46 -04:00
e01b9b5029 Shard gemm reference templates into multiple TUs for parallel compilation (#1043)
* Split apart gemm reference templates into multiple TUs for parallel compilation

* remove old files

* better balancing of ref kernels across TUs

* remove 3 newly added refcheck kernels and some unnecessary fp8 library instances to reduce lib size

* remove auto fp8 kernels

* remove some redundant kernels
2023-08-30 16:46:30 -04:00
34fd98056b fix cinttypes issue with STDC_FORMAT_MACROS (#1068)
* fix cinttypes issue with STDC_FORMAT_MACROS

* Update mma_sm90_desc.hpp

* Update mma_sm90_desc.hpp

---------

Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
2023-08-29 14:59:33 -04:00
3a8f57a3c8 Add simple hash and eq methods for gemm_operations. (#1053) 2023-08-27 20:41:57 -04:00
6673df0e48 fix typos (#1059) 2023-08-27 00:49:26 -04:00
7618e9bfd8 Fix numeric conversion warning (#1021)
* fix numeric conversion unused var

* update

---------

Co-authored-by: Lufang CHEN 陈橹方 <lufang.chen@nio.com>
2023-08-27 00:42:44 -04:00
a88c41cf8d Updates for 3.2 release (#1065) 2023-08-25 23:05:46 -04:00
27de343535 Add one Publication which is inspired by cutlass (#1022) 2023-08-22 10:00:17 -04:00
2a9fa23e06 Avoid cute::print compiler warnings with -Wformat-security (#1041)
Fixes issue #1040.
2023-08-18 14:38:27 -04:00
2e56cfabee fix typo (#1047) 2023-08-18 14:08:26 -04:00
3930f709ce Fix typo in 0x_gemm_tutorial.md (#1035) 2023-08-17 10:52:20 -04:00
7e5ee8b7bf [doc] fix: fix typos in the comment (#1049) 2023-08-16 11:39:25 -04:00
2d9a557427 torch.bfloat16 support in cutlass python (#1037)
* torch.bfloat16 support in cutlass python

* Update datatypes.py
2023-08-16 11:38:53 -04:00
4575443d44 CUTLASS 3.2 (#1024)
* CUTLASS 3.2
2023-08-07 20:50:32 -04:00
a0d787b746 Fix one publication (#1019) 2023-07-28 11:40:17 -04:00
d20f3a9542 spelling (#1007)
logicial -> logical
2023-07-20 14:41:11 -04:00
8e85580859 fix layout bug (#1006) 2023-07-19 14:26:01 -04:00
146d314057 Update fMHA kernels (#992)
* Update fMHA kernels

Upstream recent changes to fMHA that we did in xFormers.
Previous version in CUTLASS: facebookresearch/xformers@b6be33a
Updating to: facebookresearch/xformers@55a4798

* minor changes

* make var work

---------

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2023-07-12 22:30:46 -04:00
f679663224 Add RMS norm (#979) 2023-07-10 21:31:27 -04:00
e066ced33b fix epilogue iterator error (#995)
* fix epilogue iterator error

* fix epilogue iterator error

---------

Co-authored-by: maxiao <maxiao@cowarobot.com>
2023-07-10 21:30:31 -04:00
9b923dd4c4 fix minor typos (#984) 2023-07-05 09:23:01 -04:00
f6d42f2dd0 add library_dirs (#977) 2023-06-14 12:09:12 -04:00
473a67073e Fix Int8 and TF32 generator (#976) 2023-06-12 12:32:52 -04:00
87349d3496 Add grouped b2b GEMM (#970) 2023-06-05 17:16:57 -04:00
fde824af21 Update Hopper performance plot for CUTLASS 3.1 + CTK 12.1 (#967) 2023-06-01 14:52:40 -04:00
7dbf423763 Add conversion from ElementBias to ElementCompute (#961) 2023-05-26 23:08:36 -04:00
6f47420213 Update README.md 2023-05-24 12:40:31 -04:00
4638250469 Update CHANGELOG.md 2023-05-24 12:39:42 -04:00
7859fe322a Update PUBLICATIONS.md 2023-05-24 12:36:12 -04:00
d3e72719b4 Add support for sparse GEMM with row broadcasted bias vector (#951) 2023-05-24 10:25:05 -04:00
b4ab501767 Adds CUDA path for x86-64 (#957) 2023-05-24 10:21:25 -04:00
f079619f5e More updates for 3.1 (#958)
* Updates for 3.1

* Minor change

* doc link fix

* Minor updates
2023-05-24 10:17:16 -04:00
13f413493a Stream-K with broadcast (#892)
* [WIP] GEMM StreamK w/ Fused Epilogue

* Adds Gemm Streamk with Fused Epilogue kernel level struct.
  * Mostly based on Gemm with Fused Epilogue,
  * Requires a new epilogue
  * Work in progress

* [WIP] StreamK support for GemmUniversalWithBroadcast

* Just based off of how StreamK is allowed in GemmUniversal
  * Untested and a work in progress

* Minor fixes

* [WIP] It compiles!

It is almost certainly incorrect, but we're past getting the templates
to match, so checkpointing.

* Correction to reference kernel

* Fix typo

* Added MSE measurement

* Switch back to reference kernel + host for loop

Still WIP. Now we're getting an even larger MSE, but it's present on both
basic Split-K and Stream-K.

* Fix typos

* Fix broadcast vector + requested changes

* Comment typo

* Small int option and more

* Fix incorrect condition on source needed

* Requested changes

* I think I got it?

* Bias vector should be stride 0

* Two source added!

* Typos

* Merge examples

* Bring back vector row offset

Just to ensure consistency with universal gemm with fused epilogue

* Base arguments and params structs for StreamK

* StreamK epilogue with broadcast now inherits the original

* undo params_streamk_base.h

---------

Co-authored-by: Ali Hassani <ahassanijr@gmail.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2023-05-22 19:05:06 -04:00
6fbc0d3380 Update layout.md 2023-05-17 20:12:58 -04:00
b97404837e Adding 128x256 tile for 16b input datatype WGMMA gemm (#950) 2023-05-17 17:13:23 -04:00
e2953d47c5 Update gemm_api.md 2023-05-12 15:37:31 -04:00
19c4a4815e replace division with multiplication in GELU (#942) 2023-05-12 10:57:18 -04:00
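As a rough, hedged illustration of the change described above (not CUTLASS's actual epilogue code): exact GELU evaluates erf(x / sqrt(2)), and the division by sqrt(2) can be folded into a multiplication by its precomputed reciprocal.

```
#include <math.h>

// Sketch: GELU with the division by sqrt(2) replaced by a constant multiply.
__host__ __device__ inline float gelu(float x) {
  const float kRsqrt2 = 0.7071067811865475f;  // 1 / sqrt(2)
  return 0.5f * x * (1.0f + erff(x * kRsqrt2));
}
```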
fcfbd23e26 Fix host compilation of cute::cast_smem_ptr_to_uint. (#940)
* Remove references to device-only intrinsics when compiling for host.

Currently, we attempt to use the `__device__`-only functions
`__cvta_generic_to_shared` and `__nvvm_get_smem_pointer` when compiling
`cute::cast_smem_ptr_to_uint` for the host on Clang. This results in a
compilation error, as expected. This commit changes the definition of
the `*_ACTIVATED` macros so that they are only true when `__CUDA_ARCH__`
is defined; that is, when compiling for the device.

Additionally, the declaration of `__nvvm_get_smem_pointer`
is currently only visible during the device compilation pass when
compiling with NVCC; this commit makes the declaration visible during
host compilation with the `__device__` annotation.

* Annotate cute::cast_smem_ptr_to_uint as device-only.

The implementation of `cute::cast_smem_ptr_to_uint` is currently an
unchecked failure on host code, and the only host implementation I can
think of -- casting a probably-64-bit pointer to 32 bits somehow --
doesn't make sense to implement. This commit marks this function as
device-only so that it can't be accidentally used on host code.

* small change

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2023-05-10 00:06:54 -04:00
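A hedged sketch of the guard pattern this commit describes (simplified; not the actual CuTe implementation): the device-only intrinsic is referenced only when `__CUDA_ARCH__` is defined, i.e. during the device compilation pass, and the wrapper itself is annotated as device-only.

```
#include <cstdint>

// Device-only wrapper; the intrinsic is only referenced during the device pass.
__device__ inline uint32_t smem_ptr_to_uint(void const* ptr) {
#if defined(__CUDA_ARCH__)
  return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
#else
  return 0;  // Host pass still needs a well-formed body; never executed.
#endif
}
```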
b250faccd3 Make operator() const-correct and add missing static functions. (#936)
* Make operator() const-correct and add missing static functions.

Currently, `*Converter::operator()` requires a mutable object to invoke,
and there are missing `static result_type convert(source_type const &
source)` overloads for certain partial specializations of `*Converter`
objects. This commit makes `operator()` const-correct and adds missing
function overloads where appropriate.

* minor changes

* format

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2023-05-09 16:33:01 -04:00
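A minimal, hypothetical converter (not one of the actual `*Converter` partial specializations) showing the shape of this fix: a `static convert()` overload plus a const-qualified `operator()` that forwards to it, so the converter can be invoked through a const object.

```
#include <cuda_fp16.h>

struct FloatToHalfConverter {
  using result_type = __half;
  using source_type = float;

  __host__ __device__ static result_type convert(source_type const& s) {
    return __float2half(s);
  }

  // const-qualified: usable on const converter objects.
  __host__ __device__ result_type operator()(source_type const& s) const {
    return convert(s);
  }
};
```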
24c8b7d8a2 Fix cuTE compilation with clang (#939)
- clang 1.14 complains about missing function from a host call:
  cutlass/include/cute/arch/util.hpp:106:32: error: no matching function for call to '__cvta_generic_to_shared'
  return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
- fixes this by defining CUTE_HOST_DEVICE for clang as well

Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
2023-05-09 09:51:45 -04:00
7c04f95415 Updates for 3.1 (#932) 2023-04-29 09:34:27 -04:00
6f8596ce3f Add missing #include directive to get access to cutlass::epilogue::thread::ScaleType. (#925)
Currently, the `LinearCombinationClamp` header file is not standalone,
and must have the definition of `cutlass::epilogue::thread::ScaleType`
already available when it is `#include`d.
2023-04-28 20:02:41 -04:00
fe2f491dd7 Get SM count with cudaDeviceGetAttribute in KernelHardwareInfo (#927) 2023-04-28 13:23:23 -04:00
df02482f1d Add missing schedules argument in SM90 fp16 op generation (#920) 2023-04-26 16:44:49 -04:00
180c5629bf Add missing checks for NVRTC in CuTe (#921) 2023-04-25 12:52:43 -04:00
e36912f961 Fix for dangling references in the MHA example (#918) 2023-04-19 21:35:46 -04:00
9a83bd3381 CUTLASS 3.1 Python interface documentation (#917)
* Add 12.1 Dockerfile

* Add 3.1 docs
2023-04-18 15:11:35 -04:00
54bebe417d Fix some typos in CuTe tutorials (#912) 2023-04-17 16:00:51 -04:00
43cfbe0086 Allow L2 prefetch for clang compiler (#914) 2023-04-15 01:23:22 -04:00
4a68cf748e added support of b2b bmm (#849)
* added support of b2b bmm

* fixed arguments and params structures

* added batch_count argument

* removed SplitKSerial and added new test case with b2b bmm

* fixed support of Kbatched and added new test case with batch stride

* added batch support for bias and scale

* make test

* small changes

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2023-04-14 23:20:02 -04:00
d572cc1aab CUTLASS 3.1 (#915)
Co-authored-by: Aniket Shivam <ashivam@nvidia.com>
2023-04-14 23:19:34 -04:00
9b8166e3f0 fMHA: Add backward pass (#844)
* fMHA: Add backward pass

* Better checks for strides/alignments

* Remove fb-internal URL

* torch.Tensor.untyped_storage requires pytorch 2.0+

* minor changes

* make test

---------

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2023-04-06 20:44:58 -04:00
e2d439ee7e Add tile_n=32 and tile_k=32 kernels in generator.py (#858) 2023-04-06 10:00:52 -04:00
0435979f59 Remove const from 3.x GemmUniversalAdapter::operator() (#905) 2023-04-03 20:30:51 -04:00
2ba1ef10be Increase max dynamic SMEM size in GemmSoftmax (#903) 2023-04-03 10:01:12 -04:00
0964bdb64c update gemm and conv2d cmdline --help output (#878) 2023-04-01 11:38:13 -04:00
ecbd24566c Enable shared memory intrinsics and ldmatrix PTX on Clang. (#754)
* Enable shared memory intrinsics and ldmatrix PTX on Clang.

This commit adds preprocessor checks to enable the shared memory
intrinsics `__cvta_generic_to_shared` and `__nvvm_get_smem_pointer`, as
well as the `ldmatrix` PTX instructions, on Clang. Preventing these
intrinsics from being used is a significant latency regression on Clang.

* refine the macro

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2023-03-31 21:42:24 -04:00
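A hedged sketch of how such a preprocessor check can admit Clang's CUDA compilation as well as NVCC device passes (the macro name here is hypothetical; the actual refinement in this commit may differ): Clang defines `__CUDA__` when compiling CUDA sources, while `__CUDA_ARCH__` marks a device pass.

```
// Hypothetical feature-test macro, for illustration only.
#if (defined(__clang__) && defined(__CUDA__)) || defined(__CUDA_ARCH__)
#  define EXAMPLE_SMEM_INTRINSICS_ENABLED 1
#else
#  define EXAMPLE_SMEM_INTRINSICS_ENABLED 0
#endif
```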
660a05f581 fix split_k_mode and add reduction kernel for f16 input/accum/output (#896) 2023-03-30 15:31:08 -04:00
bc36122c3f [layout] Fix AffineRank2ColumnMajor::packed() (#879)
* [layout] Fix AffineRank2ColumnMajor::packed()

* correct affine2row::packed

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2023-03-29 11:59:48 -04:00
15d9d31f1f CUTLASS 3.0 Hopper GEMMs are GETTs in disguise (#897) 2023-03-29 10:42:40 -04:00
1eef5c3cf1 add guards for __CUDA_ARCH__ >= 530 (#891)
* add guards for sm>=70

* drop guard to 530
2023-03-28 17:47:10 -04:00
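The guard pattern referenced above, in a minimal sketch (illustrative only): code relying on native fp16 arithmetic is compiled only for compute capability 5.3 and newer, with a float fallback elsewhere.

```
#include <cuda_fp16.h>

__global__ void scale_half(__half* data, __half alpha) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
  data[threadIdx.x] = __hmul(data[threadIdx.x], alpha);  // native fp16 multiply
#else
  // Older architectures (and the host pass): compute in float instead.
  data[threadIdx.x] = __float2half(__half2float(data[threadIdx.x]) * __half2float(alpha));
#endif
}
```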
87070b6d51 add a CUTLASS publication (#893)
* add bytetransformer

* update arxiv link

* re-order
2023-03-28 17:06:57 -04:00
77549ae6c8 Update PUBLICATIONS.md
msft moe paper
2023-03-25 21:17:05 -04:00
42290f5d1c Fix for dangling pointers (#885) 2023-03-25 01:15:14 -04:00
209faf7b94 remove spurious comma (#871) 2023-03-20 17:25:27 -04:00
6116706c96 Set batch_strides on Params::update (#883) 2023-03-20 17:07:47 -04:00
2670b973dd Fix sign-compare warning in reorder_array (#869)
`std::vector<T>::size_type` is an unsigned type, so let's iterate with an unsigned type as well


Discovered while trying to enable PyTorch building without `-Wno-sign-compare` warning suppression, see https://github.com/pytorch/pytorch/actions/runs/4418987999/jobs/7746850762#step:10:10532
2023-03-20 17:07:24 -04:00
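A tiny sketch of the fix pattern: index with the container's own unsigned `size_type` so the comparison against `size()` is between like-signed types and `-Wsign-compare` stays quiet.

```
#include <vector>

void touch_all(std::vector<int>& values) {
  // Loop index uses std::vector's unsigned size_type.
  for (std::vector<int>::size_type i = 0; i < values.size(); ++i) {
    values[i] += 1;
  }
}
```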
af332d4aa9 Add missing comma in cutlass/arch/mma_sm90.h (#862) 2023-03-14 12:04:28 -04:00
1019 changed files with 134946 additions and 39234 deletions

View File

@ -1,5 +1,52 @@
# NVIDIA CUTLASS Changelog
## [3.2.2](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.2) (2023-10-25)
* Fixes illegal memory access issue [1138](https://github.com/NVIDIA/cutlass/issues/1138) hit by FlashAttention tests in PyTorch.
## [3.2.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1) (2023-09-22)
* Python support for SM90 Epilogue Visitor Tree (EVT) on top of the C++ support released in 3.2.0.
* SM80 EVT support in C++ and Python.
* Other SM90 epilogue improvements.
* Splitting CUTLASS library into smaller units based on operation, arch and datatypes. See [1105](https://github.com/NVIDIA/cutlass/discussions/1105) for details.
* Making `tools/library/scripts` packageable - `tools/library/scripts` is now moving to `python/cutlass_library`. See the Python [README](/python/README.md) for details.
* SM90 TF32 kernel improvements for all layouts.
* SM90 rasterization direction support in the CUTLASS profiler.
* Improvement for CUTLASS profiler build times.
* Remove Python-C++ bindings.
## [3.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.0) (2023-08-03)
* New warp-specialized persistent FP8 GEMM kernel [kernel schedules](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) and [mainloops](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. An example showcasing [Hopper warp-specialized FP8 GEMMs](/examples/54_hopper_fp8_warp_specialized_gemm). FP8 GEMMs come with a fast accumulation mode. When enabled, problem execution might be faster but at the cost of lower accuracy because intermediate results will not periodically be promoted to a higher precision.
* New [Epilogue Visitor Tree (EVT)](/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu) support for Hopper TMA epilogues. EVTs allows for user-defined customized epilogue fusion patterns without having to write a new epilogue.
* [Stream-K](/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp) feature for Hopper. Note that this is only a functional implementation of stream-K, and should not be used for performance comparison. Optimizations are expected in a future release.
* Improved CTA rasterization and support for CTA swizzling for Hopper kernels using the [Tile Scheduler](/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp).
* Improved performance for [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
* [Hopper GEMM+Permute](/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu), an example of fusing tensor reordering (permutation) with GEMM mainloop or epilogue.
* New CUTLASS 2D Convolution Python interface. New [example](/examples/python/03_basic_conv2d.ipynb) here.
* Support for Windows (MSVC) builds. Tested with Visual Studio 2019 v16.11.27 on Windows 10.0.
* Optimal performance using [**CUDA 12.2u1**](https://developer.nvidia.com/cuda-downloads)
* Updates and bugfixes from the community (thanks!)
## [3.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.1.0) (2023-04-14)
* New CUTLASS Python interface that aims to provide an ease-of-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](/python/README.md) and new [examples](/examples/python).
* New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) using TMA for Hopper.
* Support for [fused epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu), such as Bias, ReLU and GELU, using the new efficient epilogues.
* New [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
* New [*warp-specialized persistent cooperative*](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel design that allows for larger tile sizes and improves performance on Hopper.
* An [example](examples/51_hopper_gett) showcasing GEMM-Like Tensor-Tensor Contraction (GETT) capability on Hopper.
* Epilogue builders. Similar to mainloop builders (see [example 49](/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu)), epilogue builders aim to generate the best-possible epilogue while exposing incremental opt-ins for greater customization.
* Profiler support for overriding kernel and epilogue builder auto schedules for 3.x API kernels, allowing specific policies to be run in the CUTLASS profiler.
* Performance optimizations for the [*warp-specialized persistent ping-pong*](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) kernel.
* Changes to the [GEMM API 3.x](media/docs/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs.
* [FMHA Backward Pass](examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu) from Meta xFormers.
* [Streamk GEMM with Broadcast](examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu) enables epilogue broadcast with StreamK GEMM.
* [Batched B2B GEMM](examples/13_two_tensor_op_fusion) can now run multiple Back-to-Back GEMMs with the same problem size in parallel.
* [Batched Strided GEMV](test/unit/gemm/device/gemv.cu) supports both row-major and column-major input matrices.
* [Permute + GEMM fusion](examples/39_gemm_permute) can now fuse Permute with the following GEMM. Previously, only fusing GEMM with Permute in the epilogue was supported.
* [Row Broadcast](include/cutlass/epilogue/threadblock/predicated_tile_iterator_row_broadcast.h) can be fused in the epilogue.
* The GitHub branch is renamed from `master` to `main` in this release.
* Optimal performance using [**CUDA 12.1**](https://developer.nvidia.com/cuda-downloads)
* Updates and bugfixes from the community (thanks!)
## [3.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.0.0) (2023-01-23)
* [CuTe](/media/docs/cute/00_quickstart.md), a [new core library and backend](/include/cute) for CUTLASS 3.0 that defines a single Layout vocabulary type and an associated algebra of layouts for a much more expressive and composable abstraction for tensors, sets of parallel agents, and operations by said agents on tensors.
@ -57,7 +104,7 @@
* [Few channels](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h) specialization for reduced alignment capabilities
* [Fixed channels](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h) further specialized when channel count perfectly matches the access vector size
* [Unit tests](/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu)
* [Python-based instance emitter](/tools/library/scripts/generator.py) in the CUTLASS Library and support in the Profiler
* [Python-based instance emitter](/python/cutlass_library/generator.py) in the CUTLASS Library and support in the Profiler
* [BLAS3](https://docs.nvidia.com/cuda/cublas/index.html#cublas-level-3-function-reference) operators accelerated by Tensor Cores
* Supported types: f32, cf32, f64, cf64, tf32x3, complex tf32x3
* [HERK](/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu) with [emitter](/tools/library/scripts/rank_k_operation.py)

View File

@ -26,7 +26,8 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
cmake_minimum_required(VERSION 3.19 FATAL_ERROR)
cmake_policy(SET CMP0112 NEW)
if(cutlass_LOADED)
# If CUTLASS has been previously fetched and loaded, don't do it again.
@ -39,7 +40,7 @@ endif()
message(STATUS "CMake Version: ${CMAKE_VERSION}")
set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set")
project(CUTLASS VERSION 3.0.0 LANGUAGES CXX)
project(CUTLASS VERSION 3.2.2 LANGUAGES CXX)
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
if (CUDA_VERSION VERSION_LESS 11.3)
@ -58,6 +59,8 @@ endif()
find_package(Doxygen QUIET)
################################################################################
#
# CUTLASS 3.x requires C++17
#
@ -79,16 +82,41 @@ endif()
message(STATUS "Default Install Location: ${CMAKE_INSTALL_PREFIX}")
set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.")
# 0 - Sanity, 1 - Release-Quality, 2 - Exhaustive
find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED)
# Install cutlass_library Python package
execute_process(
WORKING_DIRECTORY ${CUTLASS_DIR}/python
COMMAND ${Python3_EXECUTABLE} ${CUTLASS_DIR}/python/setup_library.py develop --user
RESULT_VARIABLE cutlass_lib_GENERATOR_INSTALL_RESULT
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/cutlass_library_installation.log
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/cutlass_library_installation.log
)
if(NOT cutlass_lib_GENERATOR_INSTALL_RESULT EQUAL 0)
message(FATAL_ERROR "Error installing cutlass_library package. See ${CMAKE_CURRENT_BINARY_DIR}/cutlass_library_installation.log")
endif()
################################################################################
set(CUTLASS_ENABLE_HEADERS_ONLY OFF CACHE BOOL "Enable only the header library")
if(CUTLASS_ENABLE_HEADERS_ONLY)
set(CUTLASS_ENABLE_EXAMPLES_INIT OFF)
set(CUTLASS_ENABLE_TOOLS_INIT ON)
set(CUTLASS_ENABLE_LIBRARY_INIT OFF)
set(CUTLASS_ENABLE_TESTS_INIT OFF)
else()
set(CUTLASS_ENABLE_EXAMPLES_INIT ON)
set(CUTLASS_ENABLE_TOOLS_INIT ON)
set(CUTLASS_ENABLE_LIBRARY_INIT ON)
if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME})
set(CUTLASS_ENABLE_TESTS_INIT ON)
else()
set(CUTLASS_ENABLE_TESTS_INIT OFF)
endif()
endif()
set(CUTLASS_TEST_UNIT_ENABLE_WARNINGS OFF CACHE BOOL "Enable warnings on waived unit tests.")
@ -97,19 +125,11 @@ set(CUTLASS_ENABLE_EXAMPLES ${CUTLASS_ENABLE_EXAMPLES_INIT} CACHE BOOL "Enable C
set(CUTLASS_ENABLE_TOOLS ${CUTLASS_ENABLE_TOOLS_INIT} CACHE BOOL "Enable CUTLASS Tools")
set(CUTLASS_ENABLE_LIBRARY ${CUTLASS_ENABLE_LIBRARY_INIT} CACHE BOOL "Enable CUTLASS Library")
set(CUTLASS_ENABLE_PROFILER ${CUTLASS_ENABLE_LIBRARY} CACHE BOOL "Enable CUTLASS Profiler")
set(CUTLASS_ENABLE_PERFORMANCE ${CUTLASS_ENABLE_PROFILER} CACHE BOOL "Enable CUTLASS Proformance")
if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME})
set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_LIBRARY}})
else()
set(CUTLASS_ENABLE_TESTS_INIT OFF)
endif()
set(CUTLASS_ENABLE_PERFORMANCE ${CUTLASS_ENABLE_PROFILER} CACHE BOOL "Enable CUTLASS Performance")
set(CUTLASS_ENABLE_TESTS ${CUTLASS_ENABLE_TESTS_INIT} CACHE BOOL "Enable CUTLASS Tests")
if (CUTLASS_ENABLE_TESTS)
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/googletest.cmake)
endif()
set(CUTLASS_ENABLE_GTEST_UNIT_TESTS ${CUTLASS_ENABLE_TESTS} CACHE BOOL "Enable CUTLASS GTest-based Unit Tests")
################################################################################
set(CUTLASS_NVCC_ARCHS_SUPPORTED "")
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.4 AND NOT CUDA_COMPILER MATCHES "[Cc]lang")
@ -124,6 +144,17 @@ endif()
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.")
# Find unsupported and deprecated compute capabilities
if (CUTLASS_NVCC_ARCHS_SUPPORTED)
set(CUTLASS_NVCC_ARCHS_UNSUPPORTED ${CUTLASS_NVCC_ARCHS})
list(REMOVE_ITEM CUTLASS_NVCC_ARCHS_UNSUPPORTED ${CUTLASS_NVCC_ARCHS_SUPPORTED})
if (CUTLASS_NVCC_ARCHS_UNSUPPORTED)
message(WARNING "Using unsupported or deprecated compute capabilities ${CUTLASS_NVCC_ARCHS_UNSUPPORTED}. Support may be removed in future versions.")
endif()
else()
message(WARNING "No supported compute capabilities for CUDA ${CUDA_VERSION}.")
endif()
# Special policy introduced in CMake 3.13
if (POLICY CMP0076)
cmake_policy(SET CMP0076 NEW)
@ -161,8 +192,8 @@ if(WIN32)
endif()
if (WIN32)
# Enable more warnings and treat as errors
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/W3 -Xcompiler=/WX)
# Enable more warnings. Add "-Xcompiler=/WX" to enable warnings as errors.
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/W3)
# Disable warning on Unicode characters
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/wd4819)
@ -185,15 +216,16 @@ set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.")
################################################################################
#
# CUTLASS generator cmake configuration
#
set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma delimited list of operation name filters. Default '' means all operations are enabled.")
set(CUTLASS_LIBRARY_KERNELS "" CACHE STRING "Comma delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If 'all' is specified, all kernels are enabled.")
set(CUTLASS_LIBRARY_KERNELS ${CUTLASS_LIBRARY_KERNELS_INIT} CACHE STRING "Comma delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If 'all' is specified, all kernels are enabled.")
set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma delimited list of kernel names to exclude from build.")
# Test Levels L0, L1, L2
set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.")
################################################################################
set(CUTLASS_TEST_ENABLE_CACHED_RESULTS ON CACHE BOOL "Enable caching and reuse of test results in unit tests")
@ -213,6 +245,8 @@ if (CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED=1)
endif()
################################################################################
#
# CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations.
#
@ -262,6 +296,8 @@ if (CUTLASS_ENABLE_TENSOR_CORE_MMA)
endif()
if (NOT MSVC AND CUTLASS_NVCC_KEEP)
# MSVC flow handles caching already, but for other generators we handle it here.
set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files")
@ -287,9 +323,10 @@ if (CUTLASS_ENABLE_OPENMP_TESTS)
message(WARNING "CUTLASS_ENABLE_OPENMP_TESTS set but OpenMP not found.")
endif()
endif()
list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$<BOOL:${UNIX}>:-Xcompiler=-Wconversion>)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$<BOOL:${UNIX}>:-Xcompiler=-fno-strict-aliasing>)
if(UNIX)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-Wconversion)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-fno-strict-aliasing)
endif()
# Don't leak lineinfo in release builds
if (NOT CMAKE_BUILD_TYPE MATCHES "Release")
@ -352,6 +389,28 @@ if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.18)
cmake_policy(SET CMP0104 NEW)
endif()
if (MSVC)
# MSVC by default does not apply the correct __cplusplus version as specified by the C++ standard
# because MSVC is not a completely compliant implementation. This option forces MSVC to use the
# appropriate value given the requested --std option. This fixes a compilation issue mismatch
# between GCC/Clang and MSVC.
#
# error : a constexpr function cannot have a nonliteral return type "dim3"
#
# See https://developercommunity.visualstudio.com/t/msvc-incorrectly-defines-cplusplus/139261
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:__cplusplus")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /Zc:__cplusplus")
endif()
# Some tests require this build option in order to link.
if (MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /bigobj")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /bigobj")
endif()
function(cutlass_apply_cuda_gencode_flags TARGET)
set(options)
set(oneValueArgs)
@ -466,7 +525,8 @@ endfunction()
# GLOB for CUTLASS header files. Should we use a static list instead?
file(GLOB_RECURSE CUTLASS_INCLUDE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} include/cutlass/*.h)
file(GLOB_RECURSE CUTLASS_CUTLASS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/include include/cutlass/*.h)
file(GLOB_RECURSE CUTLASS_CUTLASS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/include include/cutlass/*.h include/cutlass/*.hpp include/cutlass/*.inl)
file(GLOB_RECURSE CUTLASS_CUTE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/include include/cute/*.h*)
file(GLOB_RECURSE CUTLASS_NVRTC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/test test/unit/nvrtc/kernel/*.h)
###################################################################################################
@ -526,11 +586,17 @@ target_include_directories(
$<INSTALL_INTERFACE:include>
$<BUILD_INTERFACE:${CUTLASS_INCLUDE_DIR}>
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/include>
$<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include>
$<BUILD_INTERFACE:${cute_SOURCE_DIR}/include>
$<BUILD_INTERFACE:${cute_SOURCE_DIR}/examples>
)
# Mark CTK headers as system to suppress warnings from them
target_include_directories(
CUTLASS
SYSTEM INTERFACE
$<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include>
)
install(
DIRECTORY
${CUTLASS_INCLUDE_DIR}/
@ -587,6 +653,11 @@ endif()
include(CTest)
enable_testing()
if (CUTLASS_ENABLE_GTEST_UNIT_TESTS)
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/googletest.cmake)
endif()
if (NOT TARGET test_all)
add_custom_target(test_all)
endif()
@ -623,7 +694,7 @@ endif()
################################################################################
set(CUTLASS_CTEST_TEMPLATE_FILE ${CMAKE_CURRENT_LIST_DIR}/cmake/CTestTestfile.config.cmake)
set(CUTLASS_CTEST_TEMPLATE_FILE ${CMAKE_CURRENT_LIST_DIR}/cmake/CTestTestfile.configure.cmake)
set(CUTLASS_CTEST_GENERATED_FILES "" CACHE INTERNAL "")
function(cutlass_add_executable_tests NAME TARGET)
@ -637,14 +708,16 @@ function(cutlass_add_executable_tests NAME TARGET)
# DEPENDS: A list of targets or files on which this test is dependent.
# DEPENDEES: A list of targets which should depend on this test.
# TEST_COMMAND_OPTIONS: A list of variables (i.e. by reference params) which contain command line arguments
# to pass to the test executable. A unique test with suffix _0, _1, ... is generated for each set of
# to pass to the test executable. A unique test is generated for each set of
# options given. If this option is not used, a single test with no arguments is generated.
# TEST_COMMAND_OPTIONS_PREFIX: If provided, is added as a prefix to each TEST_COMMAND_OPTIONS value for
# generating the full variable name to be referenced.
# RESULT_CACHE_FILE: A file to be installed alongside the test executable with pre-computed
# test results to speed up test runtime.
#
set(options DISABLE_EXECUTABLE_INSTALL_RULE)
set(oneValueArgs DISABLE_TESTS RESULT_CACHE_FILE)
set(oneValueArgs DISABLE_TESTS RESULT_CACHE_FILE TEST_COMMAND_OPTIONS_PREFIX)
set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
@ -652,6 +725,9 @@ function(cutlass_add_executable_tests NAME TARGET)
set(__DISABLE_TESTS OFF)
endif()
set(TEST_EXE $<TARGET_FILE_NAME:${TARGET}>)
set(TEST_EXE_WORKING_DIRECTORY ./${CMAKE_INSTALL_BINDIR})
if (__RESULT_CACHE_FILE)
add_custom_command(
@ -688,7 +764,6 @@ function(cutlass_add_executable_tests NAME TARGET)
endif()
list(LENGTH __TEST_COMMAND_OPTIONS CMD_COUNT)
set(CMD_IDX 0)
if (CMD_COUNT GREATER 1)
add_custom_target(${NAME} DEPENDS ${TARGET} ${__DEPENDS})
@ -697,12 +772,22 @@ function(cutlass_add_executable_tests NAME TARGET)
endforeach()
endif()
foreach(CMD_OPTIONS ${__TEST_COMMAND_OPTIONS})
if (CUTLASS_INSTALL_TESTS)
set(_INLINE_PER_TEST_CODE)
file(READ "${PROJECT_SOURCE_DIR}/cmake/CTestTestfile.test.configure.cmake" _INLINE_PER_TEST_CODE_TEMPLATE)
endif()
set(TEST_GROUP_NAME ${NAME})
foreach(CMD_OPTIONS_VAR IN LISTS __TEST_COMMAND_OPTIONS)
if (CMD_COUNT GREATER 1)
set(TEST_NAME ${NAME}_${CMD_IDX})
string(TOLOWER "${NAME}_${CMD_OPTIONS_VAR}" TEST_NAME)
else()
set(TEST_NAME ${NAME})
string(TOLOWER "${NAME}" TEST_NAME)
endif()
# The following rigmarole is needed to deal with spaces and possible quotes in
@ -711,14 +796,14 @@ function(cutlass_add_executable_tests NAME TARGET)
# preserves any quotes. Note, they have to be in this order for it to work for
# all the use cases below.
set(CMD_OPTIONS ${${CMD_OPTIONS}})
list(JOIN CMD_OPTIONS " " TEST_COMMAND_OPTIONS)
separate_arguments(CMD_OPTIONS)
set(TEST_COMMAND_OPTIONS ${${__TEST_COMMAND_OPTIONS_PREFIX}${CMD_OPTIONS_VAR}})
list(JOIN TEST_COMMAND_OPTIONS " " TEST_COMMAND_OPTIONS)
separate_arguments(TEST_COMMAND_OPTIONS)
add_custom_target(
${TEST_NAME}
COMMAND
${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $<TARGET_FILE:${TARGET}> ${CMD_OPTIONS}
${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $<TARGET_FILE:${TARGET}> ${TEST_COMMAND_OPTIONS}
DEPENDS
${TARGET}
)
@ -731,41 +816,48 @@ function(cutlass_add_executable_tests NAME TARGET)
add_dependencies(${DEPENDEE} ${TEST_NAME})
endforeach()
add_test(
NAME c${TEST_NAME}
COMMAND ${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $<TARGET_FILE:${TARGET}> ${CMD_OPTIONS}
)
set(TEST_NAME c${TEST_NAME})
string(CONFIGURE "${_INLINE_PER_TEST_CODE_TEMPLATE}" _TEST_CODE @ONLY)
string(APPEND _INLINE_PER_TEST_CODE "${_TEST_CODE}")
set_tests_properties(c${TEST_NAME} PROPERTIES DISABLED ${__DISABLE_TESTS})
endforeach()
# To run the tests from an install package with tests enabled, we need to generate test files
# that don't rely on the current directory structure in build.
set(TEST_NAME c${NAME})
set(TEST_GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/ctest/${TEST_NAME})
file(MAKE_DIRECTORY ${TEST_GEN_DIR})
set(TEST_EXE_PATH $<TARGET_FILE:${TARGET}>)
set(TEST_USE_EXTENDED_FORMAT ON)
configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake" @ONLY)
set(TEST_EXE_PATH $<TARGET_FILE_NAME:${TARGET}>)
set(TEST_USE_EXTENDED_FORMAT OFF) # ctest does not support extended add_test format.
configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake.in" @ONLY)
# The following line imports the tests for immediate run via `make test`.
include(${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake)
set(CUTLASS_CTEST_GENERATED_FILES ${CUTLASS_CTEST_GENERATED_FILES};ctest/${TEST_NAME}/CTestTestfile.${TEST_NAME}.cmake CACHE INTERNAL "")
if (CUTLASS_INSTALL_TESTS)
# To run the tests from an install package with tests enabled, we need to generate test files
# that don't rely on the current directory structure in build.
file(GENERATE
OUTPUT "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake"
INPUT "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake.in"
)
set(TEST_NAME c${TEST_NAME})
set(TEST_EXE $<TARGET_FILE_NAME:${TARGET}>)
set(TEST_EXE_WORKING_DIRECTORY ./${CMAKE_INSTALL_BINDIR})
configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.config.cmake" @ONLY)
install(
FILES "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake"
DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/ctest/${TEST_NAME}
RENAME CTestTestfile.${TEST_NAME}.cmake
)
file(GENERATE
OUTPUT "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.cmake"
INPUT "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.config.cmake"
)
install(
FILES "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.cmake"
DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/ctest/
)
set(CUTLASS_CTEST_GENERATED_FILES ${CUTLASS_CTEST_GENERATED_FILES};ctest/CTestTestfile.${TEST_NAME}.cmake CACHE INTERNAL "")
endif()
math(EXPR CMD_IDX "${CMD_IDX} + 1")
endforeach()
endfunction()
if (CUTLASS_ENABLE_TOOLS)
@ -774,6 +866,7 @@ if (CUTLASS_ENABLE_TOOLS)
add_dependencies(test_all test_profiler)
endif()
endif()
if (CUTLASS_ENABLE_EXAMPLES)
add_subdirectory(examples)
add_dependencies(test_all test_examples)
@ -781,38 +874,27 @@ endif()
if (CUTLASS_ENABLE_TESTS)
add_subdirectory(test)
if (CUTLASS_ENABLE_GTEST_UNIT_TESTS)
add_dependencies(test_all test_unit)
endif()
endif()
if (CUTLASS_INSTALL_TESTS)
file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/cmake")
file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/ctest")
file(WRITE "${CMAKE_BINARY_DIR}/cmake/CTestTestfile.cmake" "# Generated File\n")
file(WRITE "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "# Generated File\n")
foreach(GENERATED_FILE ${CUTLASS_CTEST_GENERATED_FILES})
file(APPEND "${CMAKE_BINARY_DIR}/cmake/CTestTestfile.cmake" "include(${GENERATED_FILE})\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "include(${GENERATED_FILE})\n")
endforeach()
install(
FILES "${CMAKE_BINARY_DIR}/cmake/CTestTestfile.cmake"
FILES "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake"
DESTINATION "${CUTLASS_TEST_INSTALL_PREFIX}/"
)
endif()
#? install(
#? FILES ${CMAKE_BINARY_DIR}/CTestTestfile.cmake
#? DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/
#? )
#?
#? install(
#? DIRECTORY
#? ${CMAKE_BINARY_DIR}/tools
#? ${CMAKE_BINARY_DIR}/test
#? DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/
#? FILES_MATCHING PATTERN "CTestTestfile.cmake"
#? )
################################################################################
include(CMakePackageConfigHelpers)
@ -838,3 +920,4 @@ install(
################################################################################
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassPackageConfig.cmake)

View File

@ -76,6 +76,7 @@ find_library(
PATHS
${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES
lib/x86_64-linux-gnu
lib/x64
lib64
lib
@ -120,6 +121,7 @@ find_library(
PATHS
${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES
lib/x86_64-linux-gnu
lib/x64
lib64
lib
@ -226,7 +228,14 @@ else()
endif()
set(CUTLASS_UNITY_BUILD_ENABLED ${CUTLASS_UNITY_BUILD_ENABLED_INIT} CACHE BOOL "Enable combined source compilation")
set(CUTLASS_UNITY_BUILD_BATCH_SIZE 16 CACHE STRING "Batch size for unified source files")
if (MSVC)
set(CUTLASS_UNITY_BUILD_BATCH_SIZE_INIT 8)
else()
set(CUTLASS_UNITY_BUILD_BATCH_SIZE_INIT 16)
endif()
set(CUTLASS_UNITY_BUILD_BATCH_SIZE ${CUTLASS_UNITY_BUILD_BATCH_SIZE_INIT} CACHE STRING "Batch size for unified source files")
function(cutlass_unify_source_files TARGET_ARGS_VAR)
@ -296,10 +305,10 @@ function(cutlass_add_library NAME)
if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang")
cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS})
add_library(${NAME} ${TARGET_SOURCE_ARGS})
add_library(${NAME} ${TARGET_SOURCE_ARGS} "")
else()
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
cuda_add_library(${NAME} ${TARGET_SOURCE_ARGS})
cuda_add_library(${NAME} ${TARGET_SOURCE_ARGS} "")
endif()
cutlass_apply_standard_compile_options(${NAME})

View File

@ -2,12 +2,22 @@
## 2023
- ["FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning"](https://arxiv.org/abs/2307.08691). Tri Dao. _Technical Report_, July 2023.
- ["ByteTransformer: A High-Performance Transformer Boosted for Variable-Length Inputs"](https://arxiv.org/abs/2210.03052). Yujia Zhai, Chengquan Jiang, Leyuan Wang, Xiaoying Jia, Shang Zhang, Zizhong Chen, Xin Liu, Yibo Zhu. _Proceedings of the 37th IEEE International Parallel & Distributed Processing Symposium (Best Paper)_, May 2023.
- ["A Framework for Fine-Grained Synchronization of Dependent GPU Kernels"](https://arxiv.org/abs/2305.13450). Abhinav Jangda, Saeed Maleki, Maryam Mehri Dehnavi, Madan Musuvathi, Olli Saarikivi. _Computing Research Repository_, May 2023.
- ["Graphene: An IR for Optimized Tensor Computations on GPUs"](https://dl.acm.org/doi/pdf/10.1145/3582016.3582018). Hagedorn, Bastian, Bin Fan, Hanfeng Chen, Cris Cecka, Michael Garland, Vinod Grover. _Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, March 2023.
- ["Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU"](https://arxiv.org/abs/2301.03598). Muhammad Osama, Duane Merrill, Cris Cecka, Michael Garland, John D. Owens. _arXiv_, January 2023.
## 2022
- ["GPU Load Balancing"](https://arxiv.org/abs/2212.08964). Muhammad Osama. _Doctoral dissertation, University of California, Davis_, December 2022.
- ["Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production"](https://arxiv.org/abs/2211.10017). Young Jin Kim, Rawn Henry, Raffy Fahim, Hany Hassan Awadalla. _Proceedings of the Third Workshop on Simple and Efficient Natural Language Processing_, December 2022.
- ["Bolt: Bridging the Gap between Auto-tuners and Hardware-native Performance"](https://arxiv.org/abs/2110.15238). Jiarong Xing, Leyuan Wang, Shang Zhang, Jack Chen, Ang Chen, Yibo Zhu. _Proceedings of the 5th MLSys Conference_, August 2022.
- ["Recovering single precision accuracy from Tensor Cores while surpassing the FP32 theoretical peak performance"](https://arxiv.org/abs/2203.03341). Hiroyuki Ootomo, Rio Yokota. _International Journal of High Performance Computing_, March 2022.
@ -18,7 +28,7 @@
- ["Arithmetic-intensity-guided fault tolerance for neural network inference on GPUs"](https://dl.acm.org/doi/abs/10.1145/3458817.3476184). Jack Kosaian, K. V. Rashmi. _Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis_, November 2021.
- ["Real-time Neural Radiance Caching for Path Tracing"](https://d1qx31qr3h6wln.cloudfront.net/publications/paper_4.pdf). Thomas Muller, Fabrice Rousselle, Jan Novak, Alex Keller. _ACM Trans. Graph._, August 2021.
- ["Real-time Neural Radiance Caching for Path Tracing"](https://dl.acm.org/doi/abs/10.1145/3450626.3459812). Thomas Muller, Fabrice Rousselle, Jan Novak, Alex Keller. _ACM Trans. Graph._, August 2021.
## 2020

View File

@ -1,8 +1,8 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
# CUTLASS 3.0
# CUTLASS 3.2
_CUTLASS 3.0 - January 2023_
_CUTLASS 3.2 - August 2023_
CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-matrix multiplication (GEMM) and related computations at all levels
@ -31,33 +31,39 @@ See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
See the [functionality listing](/media/docs/functionality.md) for the list of operations
supported at each level of the execution model hierarchy.
CUTLASS 3.0 introduces a new core library, CuTe, to describe and manipulate tensors of threads and data.
CUTLASS 3.0 introduced a new core library, CuTe, to describe and manipulate tensors of threads and data.
CuTe is a collection of C++ CUDA template abstractions for defining and operating on hierarchically multidimensional layouts of threads and data. CuTe provides `Layout` and `Tensor` objects that compactly package the type, shape, memory space, and layout of data, while performing the complicated indexing for the user. This lets programmers focus on the logical descriptions of their algorithms while CuTe does the mechanical bookkeeping for them. With these tools, we can quickly design, implement, and modify all dense linear algebra operations.
The core abstractions of CuTe are hierarchically multidimensional layouts which can be composed with data arrays to represent tensors. The representation of layouts is powerful enough to represent nearly everything we need to implement efficient dense linear algebra. Layouts can also be combined and manipulated via functional composition, on which we build a large set of common operations such as tiling and partitioning.
CUTLASS 3.0 adopts CuTe throughout the GEMM hierarchy in its templates. This greatly simplifies the design
CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates. This greatly simplifies the design
and improves code composability and readability. More documentation specific to CuTe can be found in its [dedicated documentation directory](/media/docs/cute/00_quickstart.md).
In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components.
# What's New in CUTLASS 3.0
# What's New in CUTLASS 3.2
CUTLASS 3.0, as the next major version of the CUTLASS API, brings with it CuTe, a new programming model and backend designed for massively parallel heterogenous agents. Using CuTe, CUTLASS 3.0 provides implementations of GEMM kernels for the NVIDIA Hopper architecture.
CUTLASS 3.2.0 is an update to CUTLASS adding:
- New warp-specialized persistent FP8 GEMM kernel [kernel schedules](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) and [mainloops](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. An example showcasing [Hopper warp-specialized FP8 GEMMs](/examples/54_hopper_fp8_warp_specialized_gemm).
- New [Epilogue Visitor Tree (EVT)](/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu) support for Hopper TMA epilogues. EVTs allows for user-defined customized epilogue fusion patterns without having to write a new epilogue.
- [Stream-K](/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp) feature for Hopper. Note that this is only a functional implementation of stream-K, and should not be used for performance comparison. Optimizations are expected in a future release.
- Improved CTA rasterization and support for CTA swizzling for Hopper kernels using the [Tile Scheduler](/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp).
- Improved performance for [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
- [Hopper GEMM+Permute](/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu), an example of fusing tensor reordering (permutation) with GEMM mainloop or epilogue.
- New CUTLASS 2D Convolution Python interface. New [example](/examples/python/03_basic_conv2d.ipynb) here.
- Support for Windows (MSVC) builds.
- [CuTe-based layouts and layout algebra](/media/docs/cute/00_quickstart.md)
- [A new GEMM template API](/media/docs/gemm_api_3x.md) that eschews the architecture-centric hierarchy of 2.x in favour of a new conceptual framing. Read more in the [3.0 design documentation](/media/docs/cutlass_3x_design.md).
- Support for 4th generation Hopper Tensor Core instructions (WGMMA) through CuTe.
- Support for Hopper asynchronous Tensor Memory Accelerator (TMA) instructions and associated transaction barriers through CuTe.
- New warp-specialized GEMM kernels targeting Hopper TMA + WGMMA for speed-of-light GEMMs.
- New warp-specialized persistent GEMM kernels targeting Hopper TMA + WGMMA.
- Support for CUDA Threadblock Clusters and programmatic TMA multicast for greater execution and data locality.
- A new way to instantiate default GEMM kernels using `CollectiveBuilder`s that supersede the 2.x `DefaultXConfiguration` types in favour a metaprogramming based kernel generator functionality. See [example 49](/examples/49_hopper_gemm_schedules_with_collective_builder/49_hopper_gemm_schedules_with_collective_builder.cu).
- Extensions to the CUTLASS library and profiler to support CUTLASS 3.0 Hopper kernels, and a new format
for kernel procedural names.
- *Announcement*: CUTLASS plans to rename the GitHub branch `master` to `main` with a future release.
CUTLASS 3.2.1 is an update to CUTLASS adding:
- Python support for SM90 Epilogue Visitor Tree (EVT) on top of the C++ support released in 3.2.0.
- SM80 EVT support in C++ and Python.
- Splitting CUTLASS library into smaller units based on operation, arch and datatypes. See [1105](https://github.com/NVIDIA/cutlass/discussions/1105) for details.
- Making `tools/library/scripts` packageable - `tools/library/scripts` is now moving to `python/cutlass_library`. See the Python [README](/python/README.md) for details.
- SM90 TF32 kernel improvements for all layouts.
- SM90 rasterization direction support in the CUTLASS profiler.
- Improvement for CUTLASS profiler build times.
## New architecture, compiler, and CUDA Toolkit requirements
CUTLASS 3.2.2 is a minor update to CUTLASS adding:
- Bug fix for illegal memory access issue hit by Flash Attention tests in PyTorch. See [1138](https://github.com/NVIDIA/cutlass/issues/1138) for details.
Minimum requirements:
@ -65,7 +71,7 @@ Minimum requirements:
- Compiler: Must support at least C++17
- CUDA Toolkit version: 11.4
CUTLASS 3.0 *removes support* for the following:
Starting from CUTLASS 3.0, CUTLASS removed support for the following:
- Maxwell and Pascal GPU architectures
- Ubuntu 16.04
@ -76,7 +82,7 @@ CUTLASS 3.0 *removes support* for the following:
# Performance
<p align="center"><img src=media/images/cutlass-3.0-gemm-peak-performance.png></p>
<p align="center"><img src=media/images/cutlass-3.1-gemm-peak-performance.png></p>
CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
they exhibit peak performance comparable to cuBLAS for scalar GEMM
@ -87,20 +93,21 @@ an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) (NVIDIA Ampere
and an [NVIDIA A40](https://www.nvidia.com/en-us/data-center/a40/) (NVIDIA Ampere architecture).
CUTLASS 3.0 was compiled with the [CUDA 12.0 Toolkit](https://developer.nvidia.com/cuda-downloads).
Tensor Core operations are implemented using CUDA's
[mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma).
[mma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma) and
[wgmma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) instructions.
<p align="center"><img src=media/images/cutlass-2.9-implicit-gemm-performance.png></p>
When using CUTLASS building blocks to construct device-wide implicit gemm (Fprop, Dgrad, and Wgrad)
kernels, CUTLASS performance is also comparable to cuDNN when running Resnet-50 layers on an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/)
as shown in the above figure. Tensor Core operations are still implemented using CUDA's
as shown in the above figure. Tensor Core operations are implemented using CUDA's
[mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma).
# Compatibility
CUTLASS requires a C++17 host compiler and
performs best when built with the [**CUDA 12.0 Toolkit**](https://developer.nvidia.com/cuda-toolkit).
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, and CUDA 11.8.
performs best when built with the [**CUDA 12.2 Toolkit**](https://developer.nvidia.com/cuda-toolkit).
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0 and CUDA 12.1.
## Operating Systems
We have tested the following environments.
@ -110,8 +117,10 @@ We have tested the following environments.
| Ubuntu 18.04 | GCC 7.5.0 |
| Ubuntu 20.04 | GCC 10.3.0 |
| Ubuntu 22.04 | GCC 11.2.0 |
| Windows 10.0 | Visual Studio 2019 v16.11.27 |
Note: We plan to add Windows (MSVC) & Clang compiler support soon.
Note: We plan to add Clang compiler support soon.
Note: GCC 8.5.0 has known regressions regarding fold expressions and overloaded operators. Using GCC 7.5.0 or (preferred) GCC >= 9 is recommended.
## Hardware
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on Volta, Turing, Ampere, Ada, and Hopper architecture based NVIDIA GPUs.
@ -131,9 +140,9 @@ CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be
## Target Architecture
In general, PTX code generated for one target architecture can be run on future architectures (i.e., it is forward compatible). However, CUDA 12.0 introduces the concept of "architecture-accelerated features" whose PTX does not have forward compatibility guarantees. Several Hopper PTX instructions fall under this category of architecture-accelerated features, and thus require a `sm_90a` target architecture (note the "a" appended). For more details on this and other architecture-accelerated instructions, please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability).
In general, PTX code generated for one target architecture can be run on future architectures (i.e., it is forward compatible). However, CUDA 12.0 introduced the concept of "architecture-accelerated features" whose PTX does not have forward compatibility guarantees. Several Hopper PTX instructions fall under this category of architecture-accelerated features, and thus require a `sm_90a` target architecture (note the "a" appended). For more details on this and other architecture-accelerated instructions, please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability).
The target architecture information is passed on to CUTLASS via the cmake flag `CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100, users are required to build CUTLASS with `90a` as the target architecture. If a user accidentally builds a kernel which uses SM90a features (e.g. Hopper Tensor Core Instructions), using the SM90 target (note the lack of "a"), with either CTK 12.0 or 11.8, the kernel is expected to fail with a runtime error.
The target architecture information is passed on to CUTLASS via the cmake flag `CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100, users are required to build CUTLASS with `90a` as the target architecture. If a user accidentally builds a kernel which uses SM90a features (e.g. Hopper Tensor Core Instructions), using the SM90 target (note the lack of "a"), with either CTK 12 or 11.8, the kernel is expected to fail with a runtime error.
```
cmake .. -DCUTLASS_NVCC_ARCHS="90a"
@ -178,7 +187,8 @@ CUTLASS is a header-only template library and does not need to be built to be us
projects. Client applications should target CUTLASS's `include/` directory in their include
paths.
CUTLASS unit tests, examples, and utilities can be build with CMake starting version 3.12.
CUTLASS unit tests, examples, and utilities can be built with CMake.
The minimum version of CMake is given in the [Quickstart guide](media/docs/quickstart.md).
Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed
on your system.
@ -514,7 +524,7 @@ reference_device: Passed
## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler
- Please follow the links for more CMake examples on selectively compiling CUTLASS kernels:
- [GEMM CMake Examples](media/docs/quickstart.md#gemm-cmake-examples)
- [Implicit GEMM conovlution CMake Examples](media/docs/quickstart.md#convolution-cmake-examples)
- [Implicit GEMM convolution CMake Examples](media/docs/quickstart.md#convolution-cmake-examples)
- [Further details about the CUTLASS Profiler are described here.](media/docs/profiler.md)
@ -558,4 +568,3 @@ SPDX-License-Identifier: BSD-3-Clause
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

View File

@ -1,21 +0,0 @@
# Generated file
if (DEFINED ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT})
set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT $ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT})
else()
set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT @CUTLASS_TEST_EXECUTION_ENVIRONMENT@)
endif()
if (NOT "@TEST_EXE_DIR@" STREQUAL "")
set(TEST_EXE_PATH @TEST_EXE_DIR@/@TEST_EXE@)
else()
set(TEST_EXE_PATH @TEST_EXE@)
endif()
add_test("@TEST_NAME@" ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@)
if (NOT "@TEST_EXE_WORKING_DIRECTORY@" STREQUAL "")
set_tests_properties("@TEST_NAME@" PROPERTIES WORKING_DIRECTORY "@TEST_EXE_WORKING_DIRECTORY@")
endif()
set_tests_properties(@TEST_NAME@ PROPERTIES DISABLED @__DISABLE_TESTS@)

View File

@ -0,0 +1,14 @@
# Generated file
set(TEST_EXE_PATH @TEST_EXE_PATH@)
set(TEST_EXE_WORKING_DIRECTORY @TEST_EXE_WORKING_DIRECTORY@)
set(CUTLASS_USE_EXTENDED_ADD_TEST_FORMAT @TEST_USE_EXTENDED_FORMAT@)
if (DEFINED ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT})
set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT $ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT})
else()
set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT @CUTLASS_TEST_EXECUTION_ENVIRONMENT@)
endif()
@_INLINE_PER_TEST_CODE@

View File

@ -0,0 +1,15 @@
if (CUTLASS_USE_EXTENDED_ADD_TEST_FORMAT)
# The longform/extended format allows generator expressions to be
# expanded properly and is useful in contexts where the files need
# to be immediately included into being-processed cmake code.
add_test(NAME @TEST_NAME@ COMMAND ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@)
else()
add_test(@TEST_NAME@ ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@)
endif()
if (TEST_EXE_WORKING_DIRECTORY)
set_tests_properties(@TEST_NAME@ PROPERTIES WORKING_DIRECTORY "${TEST_EXE_WORKING_DIRECTORY}")
endif()
set_tests_properties(@TEST_NAME@ PROPERTIES DISABLED @__DISABLE_TESTS@)

View File

@ -2,6 +2,11 @@ get_filename_component(NvidiaCutlass_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH
include(CMakeFindDependencyMacro)
if(NOT TARGET nvidia::cutlass::CUTLASS)
include("${NvidiaCutlass_CMAKE_DIR}/NvidiaCutlassTargets.cmake")
if(TARGET nvidia::cutlass::CUTLASS)
return()
endif()
include("${NvidiaCutlass_CMAKE_DIR}/NvidiaCutlassTargets.cmake")
# For backward compatibility with the old name
add_library(cutlass_lib ALIAS cutlass_library)

View File

@ -9,7 +9,7 @@ endif()
FetchContent_Declare(
googletest
GIT_REPOSITORY https://github.com/google/googletest.git
GIT_TAG 0fe9660
GIT_TAG v1.13.0
)
FetchContent_GetProperties(googletest)

View File

@ -291,8 +291,8 @@ int run() {
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementComputeEpilogue,
ElementComputeEpilogue>
int32_t,
int32_t>
gemm_device;
// Launch device reference gemm kernel
@ -355,4 +355,3 @@ int main() {
return run();
}

View File

@ -143,7 +143,6 @@ compare if the output from CUTLASS kernel is same as the reference implicit GEMM
#include "cutlass/util/tensor_view_io.h"
#include "helper.h"
// The code section below describes datatype for input, output tensors and computation between
// elements
using ElementAccumulator = int32_t; // Data type of accumulator
@ -555,6 +554,7 @@ Result profile_convolution(Options const &options) {
LayoutOutput,
ElementComputeEpilogue,
ElementAccumulator,
ElementOutput,
cutlass::NumericConverterClamp<ElementOutput, ElementComputeEpilogue>
>(
problem_size,
@ -674,7 +674,6 @@ Result profile_convolution(Options const &options) {
return result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
@ -761,11 +760,7 @@ int main(int argc, char const **args) {
Result::print_header(std::cout, options) << std::endl;
result.print(std::cout, 1, options) << std::endl;
}
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -27,7 +27,10 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# This example depends on the CUTLASS Library
#
if (CUTLASS_ENABLE_LIBRARY)
# Planar Complex GEMM example
cutlass_example_add_executable(
@ -35,11 +38,6 @@ cutlass_example_add_executable(
planar_complex.cu
)
#
# This example depends on the CUTLASS Library
#
target_link_libraries(
10_planar_complex
PRIVATE
@ -48,3 +46,4 @@ target_link_libraries(
cuda
)
endif()

View File

@ -27,7 +27,10 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# This example depends on the CUTLASS Library
#
if (CUTLASS_ENABLE_LIBRARY)
# Planar Complex Array GEMM example
cutlass_example_add_executable(
@ -35,11 +38,6 @@ cutlass_example_add_executable(
planar_complex_array.cu
)
#
# This example depends on the CUTLASS Library
#
target_link_libraries(
11_planar_complex_array
PRIVATE
@ -48,3 +46,4 @@ target_link_libraries(
cuda
)
endif()

View File

@ -64,6 +64,7 @@ endforeach()
foreach(FUSION_GEMM_EXAMPLE
fused_two_gemms_f16_sm75_rf
fused_two_gemms_f16_sm75_shmem
fused_two_gemms_grouped_f16_sm80_rf
fused_two_gemms_f16_sm80_rf
fused_two_gemms_f16_sm80_shmem
fused_two_gemms_s8_sm75_rf

View File

@ -1,11 +1,11 @@
# Introduction
This example shows fusing two back-to-back GEMMs/Convolutions into one kernel.
This example shows fusing two back-to-back GEMMs/Convolutions into one kernel.
<p align="center"><img src=/media/images/13_example_fusion.png></p>
When running two unfused GEMM/Conv operations, each operation loads one input
activation matrix, one weight matrix (or filter matrix) from the memory and then
When running two unfused GEMM/Conv operations, each operation loads one input
activation matrix, one weight matrix (or filter matrix) from the memory and then
stores the result activation matrix back to the memory.
When the two GEMM/Conv operations are fused together, the mainloops of the two
@ -27,10 +27,10 @@ In order to run two GEMM/Convs in a single kernel, the example requires the same
threadblocks are used across 2 GEMMs/Convs. This also ensures the same threadblock tile M across
2 GEMMs/Convs.
In order to reuse the output accumulator (stored in register-file) of the 1st GEMM as the
In order to reuse the output accumulator (stored in register-file) of the 1st GEMM as the
input activation, the example enforces the following two constraints:
- thread_block_tile_N = problem_N
- thread_block_tile_N = problem_N
<p align="center"><img src=/media/images/13_example_block_resident_fusion.png></p>
@ -39,7 +39,7 @@ addition to its own input activation tile. Therefore the input activation tile o
2nd GEMM/Conv only depends on the output activation tile of the 1st GEMM/Conv, and the
operation can be fully block-resident.
- warp_tile_N = thread_block_tile_N
- warp_tile_N = thread_block_tile_N
<p align="center"><img src=/media/images/13_example_rf_resident_fusion.png></p>
@ -82,7 +82,7 @@ threadblock. Typically this requires the 2nd Convolution uses 1x1 filter without
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm75_shmem`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_rf`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_shmem`
# Copyright

View File

@ -42,6 +42,7 @@
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_relu.h"
#include "reference/device/tensor_scale_bias.h"
@ -77,9 +78,9 @@ struct B2bNonFusedGemmRun
//
B2bNonFusedGemmRun(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
@ -88,7 +89,7 @@ struct B2bNonFusedGemmRun
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
@ -96,7 +97,7 @@ struct B2bNonFusedGemmRun
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
}
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
@ -129,62 +130,62 @@ struct B2bNonFusedGemmRun
/// Executes one test
bool run(
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
bool relu = true,
int warm_ups = 1,
int runs = 100) {
//
// Allocate the GEMM workspace
//
cutlass::HostTensor<
typename Gemm0::ElementA,
typename Gemm0::ElementA,
typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
cutlass::HostTensor<
typename Gemm0::ElementB,
typename Gemm0::ElementB,
typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
cutlass::HostTensor<
ElementCompute,
ElementCompute,
typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::ElementC,
typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
cutlass::HostTensor<
typename Gemm1::ElementB,
typename Gemm1::ElementB,
typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
cutlass::HostTensor<
ElementCompute,
ElementCompute,
typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::ElementC,
typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
@ -270,13 +271,13 @@ struct B2bNonFusedGemmRun
for(int i = 0; i < runs; i++) {
status = gemm_op_0();
CUTLASS_CHECK(status);
}
cudaEventRecord(stop1);
for(int i = 0; i < runs; i++) {
status = gemm_op_1();
CUTLASS_CHECK(status);
}
@ -312,32 +313,32 @@ struct B2bNonFusedGemmRun
reference_gemm_0(
problem_size_0,
alpha0,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
beta0,
alpha0,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
beta0,
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
reference_D0.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D0.device_view());
cutlass::reference::device::TensorReLu(reference_D0.device_view());
}
reference_gemm_1(
problem_size_1,
alpha1,
reference_D0.device_ref(),
tensor_B1.device_ref(),
alpha1,
reference_D0.device_ref(),
tensor_B1.device_ref(),
beta1,
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
reference_D1.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view());
cutlass::reference::device::TensorReLu(reference_D1.device_view());
}
// Wait for kernels to finish
cudaDeviceSynchronize();
reference_D0.sync_host();
@ -349,7 +350,7 @@ struct B2bNonFusedGemmRun
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(
reference_D1.host_view(),
reference_D1.host_view(),
tensor_D1.host_view());
CHECK_TRUE(passed);
@ -362,7 +363,7 @@ struct B2bNonFusedGemmRun
std::ofstream file(fname.str());
file
file
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nC0 =\n" << tensor_C0.host_view()
@ -399,9 +400,9 @@ struct B2bFusedGemmRun
//
B2bFusedGemmRun(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
@ -412,7 +413,7 @@ struct B2bFusedGemmRun
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
@ -420,11 +421,11 @@ struct B2bFusedGemmRun
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
}
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
@ -453,70 +454,90 @@ struct B2bFusedGemmRun
/// Executes one test
bool run(
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
// batch_count is used as split-k when mode is kGemm according
// to the GemmUniversal interface
int batch_count = 1,
int64_t batch_stride_A0 = 0,
int64_t batch_stride_B0 = 0,
int64_t batch_stride_C0 = 0,
int64_t batch_stride_B1 = 0,
int64_t batch_stride_C1 = 0,
int64_t batch_stride_D1 = 0,
int64_t batch_stride_Bias0 = 0,
int64_t batch_stride_Scale0 = 0,
bool relu = true,
int warm_ups = 1,
int runs = 100) {
//
// Allocate the GEMM workspace
//
cutlass::HostTensor<
typename B2bGemm::ElementA,
typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk());
cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k());
cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k());
cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k());
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn());
typename B2bGemm::ElementA,
typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn());
cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Scale0;
if(alpha0 == ElementCompute(0)) //per-channel scale
tensor_Scale0.resize({1, problem_size_0.n()});
tensor_Scale0.resize({1, batch_count * problem_size_0.n()});
cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()});
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()});
cutlass::HostTensor<
ElementAccumulator,
typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn());
ElementAccumulator,
typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D0(CoordC0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn());
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn());
cutlass::HostTensor<
ElementCompute,
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()});
typename B2bGemm::ElementC,
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()});
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D1(CoordC1.mn());
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
@ -554,6 +575,7 @@ struct B2bFusedGemmRun
//
typename B2bGemm::Arguments arguments{
mode,
problem_size_0,
problem_size_1,
tensor_A0.device_ref(),
@ -564,8 +586,16 @@ struct B2bFusedGemmRun
tensor_B1.device_ref(),
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
tensor_D1.device_ref(),
batch_stride_A0,
batch_stride_B0,
batch_stride_B1,
batch_stride_C1,
batch_stride_D1,
batch_stride_Bias0,
batch_stride_Scale0,
{alpha0, beta0},
{alpha1, beta1},
batch_count,
};
B2bGemm b2b_gemm_op;
@ -618,32 +648,31 @@ struct B2bFusedGemmRun
// Verify
//
cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
ElementAccumulator, typename B2bGemm::LayoutC,
ElementAccumulator, ElementAccumulator>
reference_gemm_0;
cutlass::reference::device::GemmComplex<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
ElementAccumulator, typename B2bGemm::LayoutC,
ElementAccumulator, ElementAccumulator
>(
cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
ElementAccumulator, typename B2bGemm::Operator>
reference_gemm_1;
reference_gemm_0(
problem_size_0,
ElementAccumulator(1), //intermediate alpha=1
tensor_A0.device_ref(),
tensor_B0.device_ref(),
tensor_A0.device_ref(),
cutlass::ComplexTransform::kNone,
tensor_B0.device_ref(),
cutlass::ComplexTransform::kNone,
ElementAccumulator(0), //beta = 0
reference_Z0.device_ref(),
reference_Z0.device_ref(),
ElementAccumulator(0)
ElementAccumulator(0),
int(batch_count),
batch_stride_A0,
batch_stride_B0,
batch_stride_C0,
batch_stride_C0
);
cutlass::reference::device::TensorScaleBiasGemm<
cutlass::reference::device::TensorScaleBiasGemmBatched<
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, typename B2bGemm::LayoutScaleBias
> (
@ -652,25 +681,45 @@ struct B2bFusedGemmRun
reference_D0.device_ref(),
alpha0,
tensor_Scale0.device_ref(),
tensor_Bias0.device_ref()
tensor_Bias0.device_ref(),
int(batch_count),
batch_stride_C0,
batch_stride_C0,
batch_stride_Scale0,
batch_stride_Bias0
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D0.device_view());
cutlass::reference::device::TensorReLu(reference_D0.device_view());
}
reference_gemm_1(
cutlass::reference::device::GemmComplex<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, ElementAccumulator
>(
problem_size_1,
alpha1,
reference_D0.device_ref(),
tensor_B1.device_ref(),
beta1,
alpha1,
reference_D0.device_ref(),
cutlass::ComplexTransform::kNone,
tensor_B1.device_ref(),
cutlass::ComplexTransform::kNone,
beta1,
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
reference_D1.device_ref()
reference_D1.device_ref(),
ElementAccumulator(0),
int(batch_count),
batch_stride_C0,
batch_stride_B1,
batch_stride_C1,
batch_stride_D1
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view());
cutlass::reference::device::TensorReLu(reference_D1.device_view());
}
cudaDeviceSynchronize();
reference_D0.sync_host();
reference_D1.sync_host();
@ -680,7 +729,7 @@ struct B2bFusedGemmRun
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(
reference_D1.host_view(),
reference_D1.host_view(),
tensor_D1.host_view());
CHECK_TRUE(passed);
@ -694,7 +743,7 @@ struct B2bFusedGemmRun
std::ofstream file(fname.str());
file
file
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nC0 =\n" << tensor_C0.host_view()

View File

@ -0,0 +1,450 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Containers for running grouped back-to-back GEMMs
*/
#pragma once
#include <iostream>
#include <fstream>
#include <sstream>
#include "cutlass/util/device_memory.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_relu.h"
#include "reference/device/tensor_scale_bias.h"
#include "helper.h"
#define CHECK_GT(val1, val2) \
if((val1) <= (val2)) \
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
#define CHECK_TRUE(val) \
if(!(val)) \
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
////////////////////////////////////////////////////////////////////////////////
template <typename B2bGemm_>
struct B2bFusedGroupedGemmRun
{
using B2bGemm = B2bGemm_;
using ElementAccumulator = typename B2bGemm::ElementAccumulator;
using ElementCompute = typename B2bGemm::BaseKernel::Epilogue::OutputOp::ElementCompute;
/// Initialization
cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C;
cutlass::Distribution::Kind init_Scale;
cutlass::Distribution::Kind init_Bias;
uint64_t seed;
//
// Methods
//
B2bFusedGroupedGemmRun(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
init_A(init_A_), init_B(init_B_), init_C(init_C_),
init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(
view.data(), view.capacity());
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view, Element(0));
}
else if (dist_kind == cutlass::Distribution::AllOnes) {
cutlass::reference::host::TensorFill(view, Element(1));
}
else {
std::cerr << "Not implemented\n";
return false;
}
return true;
}
/// Executes one test
bool run(
std::vector<cutlass::gemm::GemmCoord> problem_sizes_0,
std::vector<cutlass::gemm::GemmCoord> problem_sizes_1,
ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
bool relu = true,
int warm_ups = 1,
int runs = 100) {
using HostTensorA = cutlass::HostTensor<typename B2bGemm::ElementA, typename B2bGemm::LayoutA>;
using HostTensorB = cutlass::HostTensor<typename B2bGemm::ElementB, typename B2bGemm::LayoutB>;
using HostTensorC = cutlass::HostTensor<typename B2bGemm::ElementC, typename B2bGemm::LayoutC>;
using HostTensorScale = cutlass::HostTensor<ElementCompute, typename B2bGemm::LayoutC>;
using HostTensorZ = cutlass::HostTensor<ElementAccumulator, typename B2bGemm::LayoutC>;
using HostTensorBias = cutlass::HostTensor<ElementCompute, typename B2bGemm::LayoutC>;
int problem_count = (int)problem_sizes_0.size();
std::vector<HostTensorA> host_tensor_A0(problem_count);
std::vector<HostTensorB> host_tensor_B0(problem_count);
std::vector<HostTensorC> host_tensor_C0(problem_count);
std::vector<HostTensorScale> host_tensor_Scale0(problem_count);
std::vector<HostTensorScale> host_tensor_Bias0(problem_count);
std::vector<HostTensorB> host_tensor_B1(problem_count);
std::vector<HostTensorC> host_tensor_C1(problem_count);
std::vector<HostTensorBias> host_tensor_Bias1(problem_count);
std::vector<HostTensorC> host_tensor_D1(problem_count);
std::vector<HostTensorZ> host_tensor_Z(problem_count);
std::vector<HostTensorC> host_tensor_ref_D0(problem_count);
std::vector<HostTensorC> host_tensor_ref_D1(problem_count);
std::vector<typename HostTensorA::TensorRef> ref_A0(problem_count);
std::vector<typename HostTensorB::TensorRef> ref_B0(problem_count);
std::vector<typename HostTensorC::TensorRef> ref_C0(problem_count);
std::vector<typename HostTensorScale::TensorRef> ref_Scale0(problem_count);
std::vector<typename HostTensorScale::TensorRef> ref_Bias0(problem_count);
std::vector<typename HostTensorB::TensorRef> ref_B1(problem_count);
std::vector<typename HostTensorC::TensorRef> ref_C1(problem_count);
std::vector<typename HostTensorBias::TensorRef> ref_Bias1(problem_count);
std::vector<typename HostTensorC::TensorRef> ref_D1(problem_count);
std::vector<typename HostTensorZ::TensorRef> ref_Z(problem_count);
std::vector<typename HostTensorC::TensorRef> ref_ref_D0(problem_count);
std::vector<typename HostTensorC::TensorRef> ref_ref_D1(problem_count);
for (int i = 0; i < problem_count; ++i) {
//
// Allocate the GEMM workspace
//
auto problem_size_0 = problem_sizes_0[i];
auto problem_size_1 = problem_sizes_1[i];
host_tensor_A0.at(i) = HostTensorA(problem_size_0.mk());
host_tensor_B0.at(i) = HostTensorB(problem_size_0.kn());
host_tensor_C0.at(i) = HostTensorC(problem_size_0.mn());
if (alpha0 == ElementCompute(0)) //per-channel scale
host_tensor_Scale0.at(i) = HostTensorScale(typename HostTensorZ::Layout::TensorCoord{1, problem_size_0.n()});
host_tensor_Bias0.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_0.n()});
host_tensor_Z.at(i) = HostTensorZ(problem_size_0.mn());
host_tensor_ref_D0.at(i) = HostTensorC(problem_size_0.mn());
host_tensor_B1.at(i) = HostTensorB(problem_size_1.kn());
host_tensor_C1.at(i) = HostTensorC(problem_size_1.mn());
host_tensor_Bias1.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_1.n()});
host_tensor_D1.at(i) = HostTensorC(problem_size_1.mn());
host_tensor_ref_D1.at(i) = HostTensorC(problem_size_1.mn());
CHECK_TRUE(initialize_tensor(host_tensor_A0.at(i).host_view(), init_A, seed + 2019));
CHECK_TRUE(initialize_tensor(host_tensor_B0.at(i).host_view(), init_B, seed + 2018));
CHECK_TRUE(initialize_tensor(host_tensor_C0.at(i).host_view(), init_C, seed + 2017));
if (alpha0 == ElementCompute(0)) //per-channel scale
CHECK_TRUE(initialize_tensor(host_tensor_Scale0.at(i).host_view(), init_Scale, seed + 2014));
CHECK_TRUE(initialize_tensor(host_tensor_Bias0.at(i).host_view(), init_Bias, seed + 2013));
CHECK_TRUE(initialize_tensor(host_tensor_B1.at(i).host_view(), init_B, seed + 2016));
CHECK_TRUE(initialize_tensor(host_tensor_C1.at(i).host_view(), init_C, seed + 2015));
CHECK_TRUE(initialize_tensor(host_tensor_Bias1.at(i).host_view(), init_Bias, seed + 2012));
cutlass::reference::host::TensorFill(
host_tensor_D1.at(i).host_view());
cutlass::reference::host::TensorFill(
host_tensor_ref_D0.at(i).host_view());
cutlass::reference::host::TensorFill(
host_tensor_ref_D1.at(i).host_view());
host_tensor_A0.at(i).sync_device();
host_tensor_B0.at(i).sync_device();
host_tensor_C0.at(i).sync_device();
if (alpha0 == ElementCompute(0)) //per-channel scale
host_tensor_Scale0.at(i).sync_device();
host_tensor_Bias0.at(i).sync_device();
host_tensor_B1.at(i).sync_device();
host_tensor_C1.at(i).sync_device();
host_tensor_Bias1.at(i).sync_device();
host_tensor_D1.at(i).sync_device();
host_tensor_ref_D0.at(i).sync_device();
host_tensor_ref_D1.at(i).sync_device();
ref_A0.at(i) = (host_tensor_A0.at(i).device_ref());
ref_B0.at(i) = (host_tensor_B0.at(i).device_ref());
ref_C0.at(i) = (host_tensor_C0.at(i).device_ref());
if (alpha0 == ElementCompute(0)) //per-channel scale
ref_Scale0.at(i) = (host_tensor_Scale0.at(i).device_ref());
ref_Bias0.at(i) = (host_tensor_Bias0.at(i).device_ref());
ref_B1.at(i) = (host_tensor_B1.at(i).device_ref());
ref_C1.at(i) = {host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)};
ref_Bias1.at(i) = (host_tensor_Bias1.at(i).device_ref());
ref_D1.at(i) = (host_tensor_D1.at(i).device_ref());
ref_Z.at(i) = (host_tensor_Z.at(i).device_ref());
ref_ref_D0.at(i) = (host_tensor_ref_D0.at(i).device_ref());
ref_ref_D1.at(i) = (host_tensor_ref_D1.at(i).device_ref());
}
//
// Initialize the GEMM operator
//
cutlass::DeviceAllocation<typename HostTensorA::TensorRef> device_ref_A0(problem_count);
device_ref_A0.copy_from_host(ref_A0.data());
cutlass::DeviceAllocation<typename HostTensorB::TensorRef> device_ref_B0(problem_count);
device_ref_B0.copy_from_host(ref_B0.data());
cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_C0(problem_count);
device_ref_C0.copy_from_host(ref_C0.data());
cutlass::DeviceAllocation<typename HostTensorScale::TensorRef> device_ref_Scale0(problem_count);
device_ref_Scale0.copy_from_host(ref_Scale0.data());
cutlass::DeviceAllocation<typename HostTensorScale::TensorRef> device_ref_Bias0(problem_count);
device_ref_Bias0.copy_from_host(ref_Bias0.data());
cutlass::DeviceAllocation<typename HostTensorB::TensorRef> device_ref_B1(problem_count);
device_ref_B1.copy_from_host(ref_B1.data());
cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_C1(problem_count);
device_ref_C1.copy_from_host(ref_C1.data());
cutlass::DeviceAllocation<typename HostTensorBias::TensorRef> device_ref_Bias1(problem_count);
device_ref_Bias1.copy_from_host(ref_Bias1.data());
cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_D1(problem_count);
device_ref_D1.copy_from_host(ref_D1.data());
cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> device_problem_sizes_0(problem_count);
device_problem_sizes_0.copy_from_host(problem_sizes_0.data());
cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> device_problem_sizes_1(problem_count);
device_problem_sizes_1.copy_from_host(problem_sizes_1.data());
B2bGemm b2b_gemm_op;
int threadblock_count = B2bGemm::sufficient(problem_sizes_1.data(), problem_count);
if (!threadblock_count) {
std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped GEMM kernel." << std::endl;
return false;
}
typename B2bGemm::Arguments arguments{
problem_count,
device_problem_sizes_0.get(),
device_problem_sizes_1.get(),
device_ref_A0.get(),
device_ref_B0.get(),
device_ref_C0.get(),
device_ref_Scale0.get(),
device_ref_Bias0.get(),
device_ref_B1.get(),
device_ref_C1.get(),
device_ref_D1.get(),
{alpha0, beta0},
{alpha1, beta1},
threadblock_count
};
cutlass::Status status = b2b_gemm_op.can_implement(arguments);
if(status != cutlass::Status::kSuccess) {
std::cout << "Problem sizes not supported.\n"
<< "Requirments:\n"
<< " problem_size_0.M = problem_size_1.M\n"
<< " problem_size_0.N = problem_size_1.K\n"
<< " ThreadblockShape0::kN = problem_size_0.N\n"
<< " ThreadblockShape1::kN = problem_size_1.N" << std::endl;
}
status = b2b_gemm_op.initialize(arguments);
CUTLASS_CHECK(status);
for(int i = 0; i < warm_ups; i++) {
status = b2b_gemm_op();
CUTLASS_CHECK(status);
}
//
// Run the GEMM
//
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
cudaEventRecord(start);
for(int i = 0; i < runs; i++) {
status = b2b_gemm_op();
CUTLASS_CHECK(status);
}
cudaEventRecord(stop);
cudaDeviceSynchronize();
float gemmTime;
cudaEventElapsedTime(&gemmTime, start, stop);
std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
for (int i = 0; i < problem_count; ++i) {
host_tensor_D1.at(i).sync_host();
//
// Verify
//
cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
ElementAccumulator, typename B2bGemm::LayoutC,
ElementAccumulator, ElementAccumulator>
reference_gemm_0;
cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
ElementAccumulator>
reference_gemm_1;
auto problem_size_0 = problem_sizes_0[i];
auto problem_size_1 = problem_sizes_1[i];
reference_gemm_0(
problem_size_0,
ElementAccumulator(1), //intermediate alpha=1
ref_A0.at(i),
ref_B0.at(i),
ElementAccumulator(0), //beta = 0
ref_Z.at(i),
ref_Z.at(i),
ElementAccumulator(0)
);
cutlass::reference::device::TensorScaleBiasGemm<
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, typename B2bGemm::LayoutC
> (
problem_size_0,
ref_Z.at(i),
ref_ref_D0.at(i),
alpha0,
ref_Scale0.at(i),
ref_Bias0.at(i)
);
if(relu) {
cutlass::reference::device::TensorReLu(host_tensor_ref_D0.at(i).device_view());
}
reference_gemm_1(
problem_size_1,
alpha1,
ref_ref_D0.at(i),
ref_B1.at(i),
beta1,
{host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)},
ref_ref_D1.at(i)
);
if(relu) {
cutlass::reference::device::TensorReLu(host_tensor_ref_D1.at(i).device_view());
}
cudaDeviceSynchronize();
host_tensor_ref_D0.at(i).sync_host();
host_tensor_ref_D1.at(i).sync_host();
CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D0.at(i).host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_D1.at(i).host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D1.at(i).host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(
host_tensor_ref_D1.at(i).host_view(),
host_tensor_D1.at(i).host_view());
CHECK_TRUE(passed);
if (!passed)
{
std::stringstream fname;
fname << "error_B2bGemm_device_fused.txt";
std::cerr << "Check failed for GEMM " << i << " in the group." << std::endl;
std::cerr << "Dumping results in " << fname.str() << "\n";
std::ofstream file(fname.str());
file
<< "GEMM " << i << " in group\n"
<< "A0 =\n" << host_tensor_A0.at(i).host_view()
<< "\nB0 =\n" << host_tensor_B0.at(i).host_view()
<< "\nC0 =\n" << host_tensor_C0.at(i).host_view()
<< "\nScale0:\n" << host_tensor_Scale0.at(i).host_view() << "\n"
<< "\nBias0:\n" << host_tensor_Bias0.at(i).host_view() << "\n"
<< "\nB1 =\n" << host_tensor_B1.at(i).host_view()
<< "\nC1 =\n" << host_tensor_C1.at(i).host_view()
<< "\nBias1:\n" << host_tensor_Bias1.at(i).host_view() << "\n"
<< "\n\nReference =\n" << host_tensor_ref_D1.at(i).host_view()
<< "\nComputed =\n" << host_tensor_D1.at(i).host_view();
return false;
}
}
return true;
}
};
////////////////////////////////////////////////////////////////////////////////

View File

@ -43,6 +43,7 @@
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/host_reorder.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_relu.h"
#include "reference/device/tensor_scale_bias.h"
@ -76,9 +77,9 @@ struct B2bInterleavedNonFusedGemmRun
//
B2bInterleavedNonFusedGemmRun(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
@ -87,7 +88,7 @@ struct B2bInterleavedNonFusedGemmRun
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
@ -95,7 +96,7 @@ struct B2bInterleavedNonFusedGemmRun
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
}
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
@ -128,73 +129,72 @@ struct B2bInterleavedNonFusedGemmRun
/// Executes one test
bool run(
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
bool relu = true,
int warm_ups = 1,
int runs = 100) {
//
// Allocate the GEMM workspace
//
cutlass::HostTensor<
typename Gemm0::ElementA,
typename Gemm0::ElementA,
typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
cutlass::HostTensor<
typename Gemm0::ElementB,
typename Gemm0::ElementB,
typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
cutlass::HostTensor<
typename Gemm0::ElementB,
typename Gemm0::ElementB,
typename Gemm0::LayoutB> tensor_B0_reordered(problem_size_0.kn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::ElementC,
typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
cutlass::HostTensor<
typename Gemm1::ElementB,
typename Gemm1::ElementB,
typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
cutlass::HostTensor<
typename Gemm1::ElementB,
typename Gemm1::ElementB,
typename Gemm1::LayoutB> tensor_B1_reordered(problem_size_1.kn());
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::ElementC,
typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::ElementC,
typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
@ -285,13 +285,13 @@ struct B2bInterleavedNonFusedGemmRun
for(int i = 0; i < runs; i++) {
status = gemm_op_0();
CUTLASS_CHECK(status);
}
cudaEventRecord(stop1);
for(int i = 0; i < runs; i++) {
status = gemm_op_1();
CUTLASS_CHECK(status);
}
@ -327,36 +327,36 @@ struct B2bInterleavedNonFusedGemmRun
reference_gemm_0(
problem_size_0,
alpha0,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
beta0,
alpha0,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
beta0,
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
reference_D0.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D0.device_view());
cutlass::reference::device::TensorReLu(reference_D0.device_view());
}
reference_gemm_1(
problem_size_1,
alpha1,
reference_D0.device_ref(),
tensor_B1.device_ref(),
alpha1,
reference_D0.device_ref(),
tensor_B1.device_ref(),
beta1,
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
reference_D1.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view());
cutlass::reference::device::TensorReLu(reference_D1.device_view());
}
// Wait for kernels to finish
cudaDeviceSynchronize();
reference_D0.sync_host();
reference_D1.sync_host();
reference_D0.sync_host();
reference_D1.sync_host();
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
@ -364,7 +364,7 @@ struct B2bInterleavedNonFusedGemmRun
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(
reference_D1.host_view(),
reference_D1.host_view(),
tensor_D1.host_view());
CHECK_TRUE(passed);
@ -377,7 +377,7 @@ struct B2bInterleavedNonFusedGemmRun
std::ofstream file(fname.str());
file
file
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
@ -416,9 +416,9 @@ struct B2bInterleavedFusedGemmRun
//
B2bInterleavedFusedGemmRun(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
@ -429,7 +429,7 @@ struct B2bInterleavedFusedGemmRun
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
@ -437,11 +437,11 @@ struct B2bInterleavedFusedGemmRun
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
}
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
@ -470,78 +470,99 @@ struct B2bInterleavedFusedGemmRun
/// Executes one test
bool run(
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
// batch_count is used as split-k when mode is kGemm according
// to the GemmUniversal interface
int batch_count = 1,
int64_t batch_stride_A0 = 0,
int64_t batch_stride_B0 = 0,
int64_t batch_stride_C0 = 0,
int64_t batch_stride_B1 = 0,
int64_t batch_stride_C1 = 0,
int64_t batch_stride_D1 = 0,
int64_t batch_stride_Bias0 = 0,
int64_t batch_stride_Scale0 = 0,
bool relu = true,
int warm_ups = 1,
int runs = 100) {
//
// Allocate the GEMM workspace
//
cutlass::HostTensor<
typename B2bGemm::ElementA,
typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk());
cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k());
cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k());
cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k());
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn());
typename B2bGemm::ElementA,
typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk());
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0_reordered(problem_size_0.kn());
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0_reordered(CoordB0.kn());
cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Scale0;
if(alpha0 == ElementCompute(0)) //per-channel scale
tensor_Scale0.resize({1, problem_size_0.n()});
tensor_Scale0.resize({1, batch_count * problem_size_0.n()});
cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()});
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()});
cutlass::HostTensor<
ElementAccumulator,
typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn());
ElementAccumulator,
typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D0(CoordC0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn());
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn());
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1_reordered(problem_size_1.kn());
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1_reordered(CoordB1.kn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()});
typename B2bGemm::ElementC,
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()});
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D1(CoordC1.mn());
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
@ -556,9 +577,9 @@ struct B2bInterleavedFusedGemmRun
//Reorder B0
cutlass::reorder_column<16>(
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0);
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), CoordB0);
cutlass::reorder_column<InterleavedK_>(
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1);
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), CoordB1);
cutlass::reference::host::TensorFill(
tensor_D1.host_view());
@ -581,12 +602,14 @@ struct B2bInterleavedFusedGemmRun
tensor_D1.sync_device();
reference_D0.sync_device();
reference_D1.sync_device();
// tensor_Bias0_batched.sync_device();
//
// Initialize the GEMM operator
//
typename B2bGemm::Arguments arguments{
mode,
problem_size_0,
problem_size_1,
tensor_A0.device_ref(),
@ -597,8 +620,16 @@ struct B2bInterleavedFusedGemmRun
tensor_B1_reordered.device_ref(),
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
tensor_D1.device_ref(),
batch_stride_A0,
batch_stride_B0,
batch_stride_B1,
batch_stride_C1,
batch_stride_D1,
batch_stride_Bias0,
batch_stride_Scale0,
{alpha0, beta0},
{alpha1, beta1},
batch_count,
};
B2bGemm b2b_gemm_op;
@ -651,32 +682,30 @@ struct B2bInterleavedFusedGemmRun
// Verify
//
cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
ElementAccumulator, typename B2bGemm::LayoutC,
ElementAccumulator, ElementAccumulator>
reference_gemm_0;
cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
ElementAccumulator, typename B2bGemm::Operator>
reference_gemm_1;
reference_gemm_0(
cutlass::reference::device::GemmComplex<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
ElementAccumulator, typename B2bGemm::LayoutC,
ElementAccumulator, ElementAccumulator
>(
problem_size_0,
ElementAccumulator(1), //intermediate alpha=1
tensor_A0.device_ref(),
tensor_B0.device_ref(),
tensor_A0.device_ref(),
cutlass::ComplexTransform::kNone,
tensor_B0.device_ref(),
cutlass::ComplexTransform::kNone,
ElementAccumulator(0), //beta = 0
reference_Z0.device_ref(),
reference_Z0.device_ref(),
ElementAccumulator(0)
ElementAccumulator(0),
int(batch_count),
batch_stride_A0,
batch_stride_B0,
batch_stride_C0,
batch_stride_C0
);
cutlass::reference::device::TensorScaleBiasGemm<
cutlass::reference::device::TensorScaleBiasGemmBatched<
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, typename B2bGemm::LayoutScaleBias
> (
@ -685,25 +714,45 @@ struct B2bInterleavedFusedGemmRun
reference_D0.device_ref(),
alpha0,
tensor_Scale0.device_ref(),
tensor_Bias0.device_ref()
tensor_Bias0.device_ref(),
int(batch_count),
batch_stride_C0,
batch_stride_C0,
batch_stride_Scale0,
batch_stride_Bias0
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D0.device_view());
cutlass::reference::device::TensorReLu(reference_D0.device_view());
}
reference_gemm_1(
cutlass::reference::device::GemmComplex<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, ElementAccumulator
>(
problem_size_1,
alpha1,
reference_D0.device_ref(),
tensor_B1.device_ref(),
beta1,
alpha1,
reference_D0.device_ref(),
cutlass::ComplexTransform::kNone,
tensor_B1.device_ref(),
cutlass::ComplexTransform::kNone,
beta1,
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
reference_D1.device_ref()
reference_D1.device_ref(),
ElementAccumulator(0),
int(batch_count),
batch_stride_C0,
batch_stride_B1,
batch_stride_C1,
batch_stride_D1
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view());
cutlass::reference::device::TensorReLu(reference_D1.device_view());
}
cudaDeviceSynchronize();
reference_D0.sync_host();
reference_D1.sync_host();
@ -713,7 +762,7 @@ struct B2bInterleavedFusedGemmRun
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(
reference_D1.host_view(),
reference_D1.host_view(),
tensor_D1.host_view());
CHECK_TRUE(passed);
@ -727,7 +776,7 @@ struct B2bInterleavedFusedGemmRun
std::ofstream file(fname.str());
file
file
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view()

View File

@ -119,8 +119,6 @@ template <
int AlignmentB =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentB,
/// If true, kernel supports split-K with serial reduction
bool SplitKSerial = false,
/// Operation performed by GEMM
typename Operator_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
@ -154,7 +152,6 @@ class B2bGemm {
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
static int const kAlignmentC = EpilogueOutputOp1::kCount;
static bool const kSplitKSerial = SplitKSerial;
static ComplexTransform const kTransformA = ComplexTransform::kNone;
static ComplexTransform const kTransformB = ComplexTransform::kNone;
@ -184,77 +181,11 @@ class B2bGemm {
EpilogueOutputOp1,
ThreadblockSwizzle,
kStages,
kSplitKSerial,
Operator,
SmemAccumulator
>::B2bGemmKernel;
/// Argument structure
struct Arguments {
//
// Data members
//
GemmCoord problem_size_0;
GemmCoord problem_size_1;
TensorRef<ElementA const, LayoutA> ref_A0;
TensorRef<ElementB const, LayoutB> ref_B0;
TensorRef<ElementC const, LayoutC> ref_C0;
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0;
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0;
TensorRef<ElementB const, LayoutB> ref_B1;
TensorRef<ElementC const, LayoutC> ref_C1;
TensorRef<ElementC, LayoutC> ref_D1;
typename EpilogueOutputOp0::Params epilogue0;
typename EpilogueOutputOp1::Params epilogue1;
int split_k_slices;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments(): problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), split_k_slices(1) {
}
/// Constructs an Arguments structure
CUTLASS_HOST_DEVICE
Arguments(
GemmCoord problem_size_0_,
GemmCoord problem_size_1_,
TensorRef<ElementA const, LayoutA> ref_A0_,
TensorRef<ElementB const, LayoutB> ref_B0_,
TensorRef<ElementC const, LayoutC> ref_C0_,
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0_,
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0_,
TensorRef<ElementB const, LayoutB> ref_B1_,
TensorRef<ElementC const, LayoutC> ref_C1_,
TensorRef<ElementC, LayoutC> ref_D1_,
typename EpilogueOutputOp0::Params epilogue0_ =
typename EpilogueOutputOp0::Params(),
typename EpilogueOutputOp1::Params epilogue1_ =
typename EpilogueOutputOp1::Params(),
int split_k_slices_ = 1
):
problem_size_0(problem_size_0_),
problem_size_1(problem_size_1_),
ref_A0(ref_A0_),
ref_B0(ref_B0_),
ref_C0(ref_C0_),
ref_Scale0(ref_Scale0_),
ref_Bias0(ref_Bias0_),
ref_B1(ref_B1_),
ref_C1(ref_C1_),
ref_D1(ref_D1_),
epilogue0(epilogue0_),
epilogue1(epilogue1_),
split_k_slices(split_k_slices_) {
}
};
using Arguments = typename B2bGemmKernel::Arguments;
private:
@ -269,10 +200,6 @@ public:
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
if (!kSplitKSerial && args.split_k_slices > 1) {
return Status::kErrorInvalidProblem;
}
Status status = B2bGemmKernel::can_implement(
args.problem_size_0,
args.problem_size_1,
@ -295,20 +222,14 @@ public:
static size_t get_workspace_size(Arguments const &args) {
size_t bytes = 0;
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size_0,
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
args.split_k_slices);
if (kSplitKSerial && args.split_k_slices > 1) {
bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
}
args.batch_count);
return bytes;
}
@ -320,38 +241,17 @@ public:
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size_0,
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
args.split_k_slices);
args.batch_count);
// cutlass::gemm::GemmCoord grid_shape_1 = threadblock_swizzle.get_tiled_shape(
// args.problem_size_1,
// args.problem_size_1,
// {ThreadblockShape1::kM, ThreadblockShape1::kN, ThreadblockShape1::kK},
// args.split_k_slices);
if (kSplitKSerial) {
if (args.split_k_slices > 1) {
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
size_t bytes = get_workspace_size(args);
cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
}
else {
if (args.split_k_slices > 1) {
return Status::kErrorInvalidProblem;
}
}
// args.batch_count);
// Initialize the Params structure
params_ = typename B2bGemmKernel::Params{
args.mode,
args.problem_size_0,
args.problem_size_1,
grid_shape,
@ -363,6 +263,13 @@ public:
args.ref_B1.non_const_ref(),
args.ref_C1.non_const_ref(),
args.ref_D1,
args.batch_stride_A0,
args.batch_stride_B0,
args.batch_stride_B1,
args.batch_stride_C1,
args.batch_stride_D1,
args.batch_stride_Bias0,
args.batch_stride_Scale0,
args.epilogue0,
args.epilogue1,
static_cast<int *>(workspace),
@ -373,12 +280,6 @@ public:
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
if (kSplitKSerial && args.split_k_slices > 1) {
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
}
params_.ref_A0.reset(args.ref_A0.non_const_ref().data());
params_.ref_B0.reset(args.ref_B0.non_const_ref().data());
@ -430,12 +331,12 @@ public:
/// Runs the kernel using initialized state.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
}

View File

@ -220,7 +220,6 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_rf_res() {
return pass;
}
int main() {
std::vector<bool (*)()>funcs = {
@ -229,10 +228,6 @@ int main() {
};
return testRun(75, funcs, "conv int8 RF residency");
}
////////////////////////////////////////////////////////////////////////////////

View File

@ -39,7 +39,6 @@
#include "device/b2b_implicit_gemm_convolution.h"
#include "b2b_interleaved_conv2d_run.h"
#include "test_run.h"
////////////////////////////////////////////////////////////////////////////////
cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_0 (
@ -219,20 +218,13 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_shmem() {
return pass;
}
int main() {
std::vector<bool (*)()>funcs = {
&run_nonfused_conv2d_fprop_optimized_s8_sm75,
&run_fused_conv2d_fprop_optimized_s8_sm75_shmem
};
return testRun(75, funcs, "conv int8 shmem staging");
}
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,297 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Example of running grouped back-to-back GEMMs when intermediate results are RF resident
*/
#include <iostream>
#include <vector>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/base_grouped.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"
#include "device/b2b_gemm.h"
#include "kernel/default_b2b_gemm.h"
#include "threadblock/grouped_threadblock_swizzle.h"
#include "b2b_grouped_gemm_run.h"
#include "test_run.h"
////////////////////////////////////////////////////////////////////////////////
std::vector<cutlass::gemm::GemmCoord> gemm_f16_sm80_problem_sizes_0;
std::vector<cutlass::gemm::GemmCoord> gemm_f16_sm80_problem_sizes_1;
// Constraints:
// 1. Warp shape N must equal thread block shape N
// 2. Problem size N must equal thread block shape N
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>;
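// Illustrative only (not part of the example's required setup): the two constraints above
// can be enforced at compile time on the aliases just defined.
static_assert(WarpShape0::kN == ThreadblockShape0::kN,
    "Warp shape N of the first GEMM must equal its threadblock shape N");
static_assert(WarpShape1::kN == ThreadblockShape1::kN,
    "Warp shape N of the second GEMM must equal its threadblock shape N");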
// Command line options parsing
struct Options {
bool help;
bool error;
bool reference_check;
int alignment = 8;
std::vector<cutlass::gemm::GemmCoord> problem_sizes0;
std::vector<cutlass::gemm::GemmCoord> problem_sizes1;
int problem_count;
bool verbose;
//
// Methods
//
Options():
help(false),
error(false),
reference_check(true),
problem_count(15),
verbose(false)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("problems", problem_count, 15);
cmd.get_cmd_line_argument("reference-check", reference_check, true);
cmd.get_cmd_line_argument("verbose", verbose, false);
randomize_problems(cmd);
}
void randomize_problems(cutlass::CommandLine &cmd) {
//
// For now, randomly choose the problem sizes.
//
int cmd_line_m = -1;
int cmd_line_k = -1;
cmd.get_cmd_line_argument("m", cmd_line_m);
cmd.get_cmd_line_argument("k", cmd_line_k);
problem_sizes0.reserve(problem_count);
problem_sizes1.reserve(problem_count);
for (int i = 0; i < problem_count; ++i) {
int m = cmd_line_m;
int k = cmd_line_k;
if (m < 1) {
m = alignment * ((rand() % 256) + 1);
}
if (k < 1) {
k = alignment * ((rand() % 256) + 1);
}
cutlass::gemm::GemmCoord problem0(m, ThreadblockShape0::kN, k);
cutlass::gemm::GemmCoord problem1(m, ThreadblockShape1::kN, ThreadblockShape0::kN);
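// Chaining constraint: problem1's K is problem0's N (ThreadblockShape0::kN), since the
// first GEMM's output is consumed directly as the second GEMM's A operand.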
problem_sizes0.push_back(problem0);
problem_sizes1.push_back(problem1);
}
if (verbose) {
print_problem_sizes();
}
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "13_fused_two_gemms_grouped_f16_sm80_rf\n\n"
<< " This example runs a grouped back-to-back GEMM kernel. A group of independent back-to-back GEMMs are\n"
<< " run in a single kernel. Each indivdual problem in the group is subject to the same constraints that non-grouped\n"
<< " back-to-back GEMMs are subject to.s"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --problems=<int> Number of individual GEMM problems (default: --problems=15)\n"
<< " --m=<int> Sets the M dimension of both GEMMs for all groups. Otherwise, it is selected randomly\n"
<< " --k=<int> Sets the K dimension of the first GEMM for all groups. Otherwise, it is selected randomly\n"
<< " --verbose=<bool> If true, prints problem sizes.\n";
out << "\n\nExamples:\n\n"
<< "# Runs a grouped B2b GEMM with 10 random problem sizes\n"
<< "$ ./examples/13_two_tensor_op_fusion/13_fused_two_gemms_grouped_f16_sm80_rf --groups=10\n\n";
return out;
}
void print_problem_sizes() {
std::cout << std::endl;
std::cout << "Executing " << problem_count << " independent back-to-back GEMMs in a group" << std::endl;
for (int i = 0; i < problem_count; ++i) {
cutlass::gemm::GemmCoord problem0 = problem_sizes0.at(i);
cutlass::gemm::GemmCoord problem1 = problem_sizes1.at(i);
std::cout << "Problem " << i
<< "\t\tGEMM0: " << problem0.m() << 'x' << problem0.n() << 'x' << problem0.k()
<< "\t\tGEMM1: " << problem1.m() << 'x' << problem1.n() << 'x' << problem1.k()
<< std::endl;
}
}
};
bool run_fused_grouped_gemm_f16_sm80_rf_res() {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;
ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using EpilogueOutputOp0 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
InstructionShape::kM * InstructionShape::kN / 32,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>;
using EpilogueOutputOp1 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;
using GroupedThreadblockSwizzle = cutlass::gemm::threadblock::B2bGemmGroupedThreadblockSwizzle<
ThreadblockShape0,
cutlass::layout::RowMajor // LayoutC
>;
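// The grouped swizzle is what selects the grouped code path: kernel/default_b2b_gemm.h
// detects it via IsGroupedSwizzle and wraps the underlying B2bGemm kernel in a grouped
// kernel, as shown later in this change.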
const int kAlignment = 128 / cutlass::sizeof_bits<ElementOutput>::value;
const int kStages = 3;
using B2bGemmKernel = cutlass::gemm::kernel::DefaultB2bGemm<
cutlass::half_t,
cutlass::layout::RowMajor,
kAlignment,
cutlass::half_t,
cutlass::layout::ColumnMajor,
kAlignment,
cutlass::half_t,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape0,
ThreadblockShape1,
WarpShape0,
WarpShape1,
InstructionShape,
EpilogueOutputOp0,
EpilogueOutputOp1,
GroupedThreadblockSwizzle,
kStages,
cutlass::arch::OpMultiplyAdd
>::B2bGemmKernel;
using B2bGemm = cutlass::gemm::device::BaseGrouped<B2bGemmKernel>;
B2bFusedGroupedGemmRun<B2bGemm> fusedGemm;
std::cout << "Running Fused back-to-back FP16 TN Grouped GEMMs with RF residency...\n";
bool passed = fusedGemm.run(gemm_f16_sm80_problem_sizes_0, gemm_f16_sm80_problem_sizes_1, alpha0, beta0, alpha1, beta1);
if(passed)
std::cout << "Pass\n";
else
std::cout << "Fail\n";
return passed;
}
int main(int argc, char const **args) {
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.error) {
std::cerr << "Aborting execution." << std::endl;
return -1;
}
gemm_f16_sm80_problem_sizes_0 = options.problem_sizes0;
gemm_f16_sm80_problem_sizes_1 = options.problem_sizes1;
std::vector<bool (*)()>funcs = {
&run_fused_grouped_gemm_f16_sm80_rf_res
};
return testRun(80, funcs, "grouped gemm f16 RF residency");
}
////////////////////////////////////////////////////////////////////////////////

View File

@ -195,7 +195,6 @@ bool run_fused_gemm_s8_rf_res() {
return passed;
}
int main() {
std::vector<bool (*)()>funcs = {
@ -204,9 +203,6 @@ int main() {
};
return testRun(75, funcs, "gemm int8 RF residency");
}
////////////////////////////////////////////////////////////////////////////////

View File

@ -43,7 +43,6 @@
#include "device/b2b_gemm.h"
#include "b2b_interleaved_gemm_run.h"
#include "test_run.h"
////////////////////////////////////////////////////////////////////////////////
cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_0(128*640, 64, 576);
@ -197,18 +196,13 @@ bool run_fused_gemm_s8_shmem() {
return passed;
}
int main() {
std::vector<bool (*)()>funcs = {
&run_nonfused_gemm_s8,
&run_fused_gemm_s8_shmem
};
return testRun(75, funcs, "gemm int8 shmem staging");
}
////////////////////////////////////////////////////////////////////////////////

View File

@ -152,7 +152,7 @@ bool run_fused_gemm_s8_sm80_rf_res() {
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
using EpilogueOutputOp0 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
8 * InstructionShape::kN / 32,
@ -161,7 +161,7 @@ bool run_fused_gemm_s8_sm80_rf_res() {
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>;
using EpilogueOutputOp1 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
64 / cutlass::sizeof_bits<ElementOutput>::value,
@ -194,14 +194,21 @@ bool run_fused_gemm_s8_sm80_rf_res() {
SmemAccumulator,
16,
16,
false,
cutlass::arch::OpMultiplyAddSaturate
>;
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with RF residency...\n";
bool passed = fusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
bool passed = fusedGemm.run(
gemm_s8_sm80_problem_size_0,
gemm_s8_sm80_problem_size_1,
alpha0,
beta0,
alpha1,
beta1
);
if(passed)
std::cout << "Pass\n";
else
@ -210,18 +217,123 @@ bool run_fused_gemm_s8_sm80_rf_res() {
return passed;
}
bool run_fused_gemm_s8_sm80_rf_res_batch() {
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(256, 64, 128);
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(256, 128, 64);
using ElementOutput = int8_t;
using ElementAccumulator = int32_t;
using ElementCompute = float;
ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
using EpilogueOutputOp0 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
8 * InstructionShape::kN / 32,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>;
using EpilogueOutputOp1 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;
const bool SmemAccumulator = false;
using B2bGemm = cutlass::gemm::device::B2bGemm<
int8_t,
cutlass::layout::ColumnMajorInterleaved<32>,
int8_t,
cutlass::layout::RowMajorInterleaved<32>,
ElementOutput,
cutlass::layout::ColumnMajorInterleaved<32>,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape0,
ThreadblockShape1,
WarpShape0,
WarpShape1,
InstructionShape,
EpilogueOutputOp0,
EpilogueOutputOp1,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3,
SmemAccumulator,
16,
16,
cutlass::arch::OpMultiplyAddSaturate
>;
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
int batch_count = 2;
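// Batch strides below are element counts between consecutive batches: full matrices for
// A0, B0, C0 (the intermediate), B1 and D1, one length-n bias row for C1 and Bias0, and
// stride 0 for Scale0 since the first epilogue above uses OnlyAlphaScaling and the scale
// vector is unused here.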
int64_t batch_stride_A0 = gemm_s8_sm80_problem_size_0.m() * gemm_s8_sm80_problem_size_0.k();
int64_t batch_stride_B0 = gemm_s8_sm80_problem_size_1.k() * gemm_s8_sm80_problem_size_1.n();
int64_t batch_stride_C0 = gemm_s8_sm80_problem_size_0.m() * gemm_s8_sm80_problem_size_0.n();
int64_t batch_stride_B1 = gemm_s8_sm80_problem_size_1.k() * gemm_s8_sm80_problem_size_1.n();
int64_t batch_stride_C1 = gemm_s8_sm80_problem_size_1.n();
int64_t batch_stride_D1 = gemm_s8_sm80_problem_size_1.m() * gemm_s8_sm80_problem_size_1.n();
int64_t batch_stride_Bias0 = gemm_s8_sm80_problem_size_0.n();
int64_t batch_stride_Scale0 = 0;
std::cout << "Running Fused back-to-back INT8 NT interleaved Batched GEMMs with RF residency...\n";
bool passed = fusedGemm.run(
gemm_s8_sm80_problem_size_0,
gemm_s8_sm80_problem_size_1,
alpha0,
beta0,
alpha1,
beta1,
cutlass::gemm::GemmUniversalMode::kBatched,
batch_count,
batch_stride_A0,
batch_stride_B0,
batch_stride_C0,
batch_stride_B1,
batch_stride_C1,
batch_stride_D1,
batch_stride_Bias0,
batch_stride_Scale0
);
if(passed)
std::cout << "Pass\n";
else
std::cout << "Fail\n";
return passed;
}
int main() {
std::vector<bool (*)()>funcs = {
&run_nonfused_gemm_s8_sm80,
&run_fused_gemm_s8_sm80_rf_res,
&run_fused_gemm_s8_sm80_rf_res_batch
};
return testRun(80, funcs, "gemm int8 RF residency");
}
////////////////////////////////////////////////////////////////////////////////

View File

@ -151,7 +151,7 @@ bool run_fused_gemm_s8_sm80_shmem() {
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
using EpilogueOutputOp0 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
8 * InstructionShape::kN / 32,
@ -160,7 +160,7 @@ bool run_fused_gemm_s8_sm80_shmem() {
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>;
using EpilogueOutputOp1 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
64 / cutlass::sizeof_bits<ElementOutput>::value,
@ -168,7 +168,7 @@ bool run_fused_gemm_s8_sm80_shmem() {
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;
const bool SmemAccumulator = true;
using B2bGemm = cutlass::gemm::device::B2bGemm<
@ -193,7 +193,6 @@ bool run_fused_gemm_s8_sm80_shmem() {
SmemAccumulator,
16,
16,
false,
cutlass::arch::OpMultiplyAddSaturate
>;

View File

@ -40,19 +40,66 @@
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "kernel/b2b_gemm_grouped_problem_visitor.h"
#include "threadblock/grouped_threadblock_swizzle.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
namespace detail {
/// Utility struct for returning the type of the problem visitor used by the swizzling function,
/// if it is a grouped swizzling function, or a default visitor. This is used only for defining
/// the parameters of the problem visitor used in GroupedParams.
template <
typename B2bMma_,
typename ThreadblockSwizzle_,
typename Enable = void
>
struct ProblemVisitorOrDefault;
/// Return a generic problem visitor for GEMM problems
template <
typename B2bMma_,
typename ThreadblockSwizzle_
>
struct ProblemVisitorOrDefault<B2bMma_,
ThreadblockSwizzle_,
typename platform::enable_if<
! cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle_>::value
>::type> {
using value = B2bGemmGroupedProblemVisitor<typename B2bMma_::Shape,
GroupScheduleMode::kDeviceOnly,
128,
128,
platform::is_same<typename B2bMma_::LayoutC,
cutlass::layout::ColumnMajor>::value>;
};
/// Return the problem visitor specified by the swizzling function
template <
typename B2bMma_,
typename ThreadblockSwizzle_
>
struct ProblemVisitorOrDefault<B2bMma_,
ThreadblockSwizzle_,
typename platform::enable_if<
cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle_>::value
>::type> {
using value = typename ThreadblockSwizzle_::ProblemVisitor;
};
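/// Usage sketch (type names are placeholders): the trait yields the swizzle's own
/// ProblemVisitor for grouped swizzles and a default B2bGemmGroupedProblemVisitor otherwise:
///   using Visitor = typename detail::ProblemVisitorOrDefault<SomeB2bMma, SomeSwizzle>::value;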
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
struct B2bGemm {
@ -61,14 +108,184 @@ struct B2bGemm {
using OutputOp0 = typename B2bMma::OutputOp;
using OutputOp1 = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static bool const kSplitKSerial = SplitKSerial;
using ElementA0 = typename B2bMma::IteratorA0::Element;
using LayoutA0 = typename B2bMma::IteratorA0::Layout;
using ElementB0 = typename B2bMma::IteratorB0::Element;
using LayoutB0 = typename B2bMma::IteratorB0::Layout;
using ElementB1 = typename B2bMma::IteratorB1::Element;
using LayoutB1 = typename B2bMma::IteratorB1::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
using ScaleBiasData = typename B2bMma::IteratorAccumulatorScaleBias::Element;
/// Data types needed for higher-level containers. In some cases, a single type must be exposed
/// despite the B2b GEMM using two GEMMs under the hood. In such cases, we select the values from
/// the second GEMM (other than for ElementA/ElementB)
using ElementA = typename B2bMma::IteratorA0::Element;
using LayoutA = typename B2bMma::IteratorA0::Layout;
using ElementB = typename B2bMma::IteratorB0::Element;
using LayoutB = typename B2bMma::IteratorB0::Layout;
static ComplexTransform const kTransformA = B2bMma::kTransformA;
static ComplexTransform const kTransformB = B2bMma::kTransformB;
using Operator = typename B2bMma::Operator0;
using OperatorClass = typename Operator::OperatorClass;
using ThreadblockShape = typename B2bMma::Shape0;
using WarpShape = typename Operator::Shape;
using InstructionShape = typename Operator::InstructionShape;
using ArchTag = typename B2bMma::ArchTag;
static int const kStages = B2bMma::kStages;
static int const kAlignmentA = B2bMma::IteratorA::AccessType::kElements;
static int const kAlignmentB = B2bMma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
using Mma = B2bMma;
using EpilogueOutputOp = OutputOp1;
/// Warp count (concept: GemmShape)
using WarpCount0 = typename B2bMma::WarpCount0;
static int const kThreadCount = 32 * WarpCount0::kCount;
/// Argument structure
struct Arguments {
//
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size_0;
GemmCoord problem_size_1;
typename B2bMma::IteratorA0::TensorRef ref_A0;
typename B2bMma::IteratorB0::TensorRef ref_B0;
typename Epilogue::OutputTileIterator::TensorRef ref_C0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0;
typename B2bMma::IteratorB1::TensorRef ref_B1;
typename Epilogue::OutputTileIterator::TensorRef ref_C1;
typename Epilogue::OutputTileIterator::TensorRef ref_D1;
int64_t batch_stride_A0;
int64_t batch_stride_B0;
int64_t batch_stride_B1;
int64_t batch_stride_C1;
int64_t batch_stride_D1;
int64_t batch_stride_Bias0;
int64_t batch_stride_Scale0;
typename OutputOp0::Params epilogue0;
typename OutputOp1::Params epilogue1;
int batch_count;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments() : mode(GemmUniversalMode::kGemm), problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), batch_count(1) {}
/// Constructs an Arguments structure
CUTLASS_HOST_DEVICE
Arguments(
GemmUniversalMode mode_,
GemmCoord problem_size_0_,
GemmCoord problem_size_1_,
typename B2bMma::IteratorA0::TensorRef ref_A0_,
typename B2bMma::IteratorB0::TensorRef ref_B0_,
typename Epilogue::OutputTileIterator::TensorRef ref_C0_,
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0_,
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0_,
typename B2bMma::IteratorB1::TensorRef ref_B1_,
typename Epilogue::OutputTileIterator::TensorRef ref_C1_,
typename Epilogue::OutputTileIterator::TensorRef ref_D1_,
int64_t batch_stride_A0_,
int64_t batch_stride_B0_,
int64_t batch_stride_B1_,
int64_t batch_stride_C1_,
int64_t batch_stride_D1_,
int64_t batch_stride_Bias0_,
int64_t batch_stride_Scale0_,
typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(),
typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(),
int batch_count_ = 1
):
mode(mode_),
problem_size_0(problem_size_0_),
problem_size_1(problem_size_1_),
ref_A0(ref_A0_),
ref_B0(ref_B0_),
ref_C0(ref_C0_),
ref_Scale0(ref_Scale0_),
ref_Bias0(ref_Bias0_),
ref_B1(ref_B1_),
ref_C1(ref_C1_),
ref_D1(ref_D1_),
batch_stride_A0(batch_stride_A0_),
batch_stride_B0(batch_stride_B0_),
batch_stride_B1(batch_stride_B1_),
batch_stride_C1(batch_stride_C1_),
batch_stride_D1(batch_stride_D1_),
batch_stride_Bias0(batch_stride_Bias0_),
batch_stride_Scale0(batch_stride_Scale0_),
epilogue0(epilogue0_),
epilogue1(epilogue1_),
batch_count(batch_count_) {
}
};
// Arguments structure for grouped B2B problems
struct GroupedArguments {
GemmCoord* problem_size_0;
GemmCoord* problem_size_1;
typename B2bMma::IteratorA0::TensorRef* ref_A0;
typename B2bMma::IteratorB0::TensorRef* ref_B0;
typename Epilogue::OutputTileIterator::TensorRef* ref_C0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0;
typename B2bMma::IteratorB1::TensorRef* ref_B1;
typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
typename Epilogue::OutputTileIterator::TensorRef* ref_D1;
// Epilogue params remain constant across all problems in the group. Thus,
// the parameter here is not a pointer.
typename OutputOp0::Params epilogue0;
typename OutputOp1::Params epilogue1;
int problem_count;
int threadblock_count;
GemmCoord* host_problem_sizes;
CUTLASS_HOST_DEVICE
GroupedArguments(
int problem_count,
GemmCoord* problem_size_0_,
GemmCoord* problem_size_1_,
typename B2bMma::IteratorA0::TensorRef* ref_A0_,
typename B2bMma::IteratorB0::TensorRef* ref_B0_,
typename Epilogue::OutputTileIterator::TensorRef* ref_C0_,
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0_,
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0_,
typename B2bMma::IteratorB1::TensorRef* ref_B1_,
typename Epilogue::OutputTileIterator::TensorRef* ref_C1_,
typename Epilogue::OutputTileIterator::TensorRef* ref_D1_,
typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(),
typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(),
int threadblock_count = 0
) : problem_size_0(problem_size_0_), problem_size_1(problem_size_1_),
ref_A0(ref_A0_), ref_B0(ref_B0_), ref_C0(ref_C0_),
ref_Scale0(ref_Scale0_), ref_Bias0(ref_Bias0_), ref_B1(ref_B1_),
ref_C1(ref_C1_), ref_D1(ref_D1_), epilogue0(epilogue0_), epilogue1(epilogue1_),
problem_count(problem_count),
threadblock_count(threadblock_count)
{}
};
/// Parameters structure
struct Params {
cutlass::gemm::GemmUniversalMode mode;
cutlass::gemm::GemmCoord problem_size_0;
cutlass::gemm::GemmCoord problem_size_1;
cutlass::gemm::GemmCoord grid_tiled_shape;
@ -89,6 +306,13 @@ struct B2bGemm {
typename Epilogue::OutputTileIterator::TensorRef ref_D1;
typename OutputOp0::Params output_op_0;
typename OutputOp1::Params output_op_1;
int64_t batch_stride_A0;
int64_t batch_stride_B0;
int64_t batch_stride_B1;
int64_t batch_stride_C1;
int64_t batch_stride_D1;
int64_t batch_stride_Bias0;
int64_t batch_stride_Scale0;
int *semaphore;
int gemm_k_iterations_0;
int gemm_k_size_0;
@ -100,11 +324,12 @@ struct B2bGemm {
//
CUTLASS_HOST_DEVICE
Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
Params(): mode(cutlass::gemm::GemmUniversalMode::kGemm), swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
gemm_k_iterations_1(0), gemm_k_size_1(0) { }
CUTLASS_HOST_DEVICE
Params(
cutlass::gemm::GemmUniversalMode mode,
cutlass::gemm::GemmCoord const & problem_size_0,
cutlass::gemm::GemmCoord const & problem_size_1,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
@ -116,14 +341,22 @@ struct B2bGemm {
typename B2bMma::IteratorB1::TensorRef ref_B1,
typename Epilogue::OutputTileIterator::TensorRef ref_C1,
typename Epilogue::OutputTileIterator::TensorRef ref_D1,
int64_t batch_stride_A0,
int64_t batch_stride_B0,
int64_t batch_stride_B1,
int64_t batch_stride_C1,
int64_t batch_stride_D1,
int64_t batch_stride_Bias0,
int64_t batch_stride_Scale0,
typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(),
typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(),
int *workspace = nullptr
):
mode(mode),
problem_size_0(problem_size_0),
problem_size_1(problem_size_1),
grid_tiled_shape(grid_tiled_shape),
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
swizzle_log_tile(ThreadblockSwizzle::get_log_tile(grid_tiled_shape)),
params_A0(ref_A0.layout()),
ref_A0(ref_A0),
params_B0(ref_B0.layout()),
@ -138,6 +371,13 @@ struct B2bGemm {
ref_C1(ref_C1),
params_D1(ref_D1.layout()),
ref_D1(ref_D1),
batch_stride_A0(batch_stride_A0),
batch_stride_B0(batch_stride_B0),
batch_stride_B1(batch_stride_B1),
batch_stride_C1(batch_stride_C1),
batch_stride_D1(batch_stride_D1),
batch_stride_Bias0(batch_stride_Bias0),
batch_stride_Scale0(batch_stride_Scale0),
output_op_0(output_op_0),
output_op_1(output_op_1) {
@ -152,6 +392,81 @@ struct B2bGemm {
}
};
struct GroupedParams {
cutlass::gemm::GemmCoord* problem_size_0;
cutlass::gemm::GemmCoord* problem_size_1;
cutlass::gemm::GemmCoord* grid_tiled_shape;
typename B2bMma::IteratorA0::TensorRef* ref_A0;
typename B2bMma::IteratorB0::TensorRef* ref_B0;
typename Epilogue::OutputTileIterator::TensorRef* ref_C0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0;
typename B2bMma::IteratorB1::TensorRef* ref_B1;
typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
typename Epilogue::OutputTileIterator::TensorRef* ref_D1;
// Epilogue params remain constant across all problems in the group. Thus,
// the parameter here is not a pointer.
typename OutputOp0::Params output_op_0;
typename OutputOp1::Params output_op_1;
using ProblemVisitor = typename detail::ProblemVisitorOrDefault<B2bMma, ThreadblockSwizzle>::value;
typename ProblemVisitor::Params problem_visitor;
int threadblock_count;
int* workspace;
CUTLASS_HOST_DEVICE
GroupedParams() {}
CUTLASS_HOST_DEVICE
GroupedParams(
GroupedArguments const &args,
void *workspace = nullptr,
int tile_count = 0
) :
problem_size_0(args.problem_size_0), problem_size_1(args.problem_size_1),
ref_A0(args.ref_A0), ref_B0(args.ref_B0), ref_C0(args.ref_C0),
ref_Scale0(args.ref_Scale0), ref_Bias0(args.ref_Bias0), ref_B1(args.ref_B1), ref_C1(args.ref_C1), ref_D1(args.ref_D1),
output_op_0(args.epilogue0), output_op_1(args.epilogue1),
problem_visitor(args.problem_size_0, args.problem_size_1, args.problem_count, workspace, tile_count),
threadblock_count(args.threadblock_count),
workspace(reinterpret_cast<int*>(workspace)) {}
CUTLASS_HOST_DEVICE
void transpose() {
// Only row-major outputs are currently supported, so no transpose is performed
}
/// Returns non-grouped parameters to be used as input to the kernel-level
/// operator for the problem indicated by problem_visitor.
CUTLASS_HOST_DEVICE
Params to_single_params(const ProblemVisitor& problem_visitor) const {
GemmCoord problem_size0 = problem_visitor.problem_size0();
GemmCoord problem_size1 = problem_visitor.problem_size1();
int32_t idx = problem_visitor.problem_index();
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size1);
return Params(
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size0,
problem_size1,
grid_shape,
ref_A0[idx],
ref_B0[idx],
ref_C0[idx],
ref_Scale0[idx],
ref_Bias0[idx],
ref_B1[idx],
ref_C1[idx],
ref_D1[idx],
0, 0, 0, 0, 0, 0, 0, // Batched B2B GEMMs within the grouped kernel are currently unsupported
output_op_0,
output_op_1,
workspace
);
}
};
/// Shared memory storage structure
union SharedStorage {
typename B2bMma::B2bMmaSharedStorage main_loop;
@ -163,7 +478,7 @@ struct B2bGemm {
//
CUTLASS_HOST_DEVICE
B2bGemm() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(
@ -223,7 +538,7 @@ struct B2bGemm {
if(problem_size_0.n() > B2bMma::Shape0::kN)
return Status::kErrorInvalidProblem;
if(problem_size_1.n() > B2bMma::Shape1::kN)
return Status::kErrorInvalidProblem;
@ -233,9 +548,13 @@ struct B2bGemm {
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
run_with_swizzle(params, shared_storage, threadblock_swizzle);
}
/// Executes one GEMM with an externally-provided swizzling function
CUTLASS_DEVICE
void run_with_swizzle(Params const &params, SharedStorage &shared_storage, ThreadblockSwizzle& threadblock_swizzle) {
cutlass::gemm::GemmCoord threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
@ -247,37 +566,64 @@ struct B2bGemm {
return;
}
ElementA0 *ptr_A0 = static_cast<ElementA0 *>(params.ref_A0.data());
ElementB0 *ptr_B0 = static_cast<ElementB0 *>(params.ref_B0.data());
ElementB1 *ptr_B1 = static_cast<ElementB1 *>(params.ref_B1.data());
ScaleBiasData *ptr_Bias0 = static_cast<ScaleBiasData *>(params.ref_Bias0.data());
ScaleBiasData *ptr_Scale0 = static_cast<ScaleBiasData *>(params.ref_Scale0.data());
int offset_k_0 = 0;
int offset_k_1 = 0;
int problem_size_k_0 = params.problem_size_0.k();
int problem_size_k_1 = params.problem_size_1.k();
if (params.mode == GemmUniversalMode::kGemm) {
// Problem size is a function of threadblock index in the K dimension
problem_size_k_0 = min(
problem_size_k_0,
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
// Problem size is a function of threadblock index in the K dimension
problem_size_k_1 = min(
problem_size_k_1,
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
offset_k_0 = threadblock_tile_offset.k() * params.gemm_k_size_0;
offset_k_1 = threadblock_tile_offset.k() * params.gemm_k_size_1;
}
else if (params.mode == GemmUniversalMode::kBatched) {
ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A0;
ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0;
ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1;
ptr_Bias0 += threadblock_tile_offset.k() * params.batch_stride_Bias0;
ptr_Scale0 += threadblock_tile_offset.k() * params.batch_stride_Scale0;
}
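// In kBatched mode the threadblock's K tile index serves as the batch index, so each
// pointer above advances by one batch stride per unit of threadblock_tile_offset.k().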
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A0{
threadblock_tile_offset.m() * B2bMma::Shape0::kM,
threadblock_tile_offset.k() * params.gemm_k_size_0,
offset_k_0,
};
cutlass::MatrixCoord tb_offset_B0{
threadblock_tile_offset.k() * params.gemm_k_size_0,
offset_k_0,
threadblock_tile_offset.n() * B2bMma::Shape0::kN
};
cutlass::MatrixCoord tb_offset_B1{
threadblock_tile_offset.k() * params.gemm_k_size_1,
offset_k_1,
threadblock_tile_offset.n() * B2bMma::Shape1::kN
};
// Problem size is a function of threadblock index in the K dimension
int problem_size_k_0 = min(
params.problem_size_0.k(),
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK;
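// Ceiling division: the number of Shape0::kK-wide K tiles this threadblock walks for the
// first GEMM, starting from its K offset.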
// Problem size is a function of threadblock index in the K dimension
int problem_size_k_1 = min(
params.problem_size_1.k(),
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
// Compute threadblock-scoped matrix multiply-add
// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
// Compute position within threadblock
@ -286,34 +632,33 @@ struct B2bGemm {
// Construct iterators to A and B operands
typename B2bMma::IteratorA0 iterator_A0(
params.params_A0,
params.ref_A0.data(),
ptr_A0,
{params.problem_size_0.m(), problem_size_k_0},
thread_idx,
tb_offset_A0);
typename B2bMma::IteratorB0 iterator_B0(
params.params_B0,
params.ref_B0.data(),
ptr_B0,
{problem_size_k_0, params.problem_size_0.n()},
thread_idx,
tb_offset_B0);
typename B2bMma::IteratorB1 iterator_B1(
params.params_B1,
params.ref_B1.data(),
ptr_B1,
{problem_size_k_1, params.problem_size_1.n()},
thread_idx,
tb_offset_B1);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
// Construct iterators to accumulator scale/bias vector
typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0(
params.ref_Scale0.data(),
ptr_Scale0,
{1, params.problem_size_0.n()},
thread_idx,
warp_idx,
@ -323,7 +668,7 @@ struct B2bGemm {
);
typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0(
params.ref_Bias0.data(),
ptr_Bias0,
{1, params.problem_size_0.n()},
thread_idx,
warp_idx,
@ -332,14 +677,17 @@ struct B2bGemm {
)
);
//
// Main loop
//
OutputOp0 output_op_0(params.output_op_0);
if (cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle>::value) {
// Wait for all threads to finish their epilogue phases from the previous tile.
__syncthreads();
}
// Construct thread-scoped matrix multiply
B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx, params.problem_size_0.n());
@ -349,11 +697,9 @@ struct B2bGemm {
src_accum.clear();
accumulators.clear();
if (!kSplitKSerial || gemm_k_iterations_0 > 0) {
// Compute threadblock-scoped matrix multiply-add
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
}
// Compute threadblock-scoped matrix multiply-add
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
//
// Epilogue
@ -376,23 +722,32 @@ struct B2bGemm {
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
ElementC *ptr_C1 = static_cast<ElementC *>(params.ref_C1.data());
ElementC *ptr_D1 = static_cast<ElementC *>(params.ref_D1.data());
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
if (params.mode == GemmUniversalMode::kGemm) {
  // If performing a reduction via split-K, fetch the initial synchronization
  if (params.grid_tiled_shape.k() > 1) {
    // Fetch the synchronization lock initially but do not block.
    semaphore.fetch();
    // Indicate which position in a serial reduction the output operator is currently updating
    output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
  }
}
else if (params.mode == GemmUniversalMode::kBatched) {
ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1;
ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D1;
}
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C1(
params.params_C1,
params.ref_C1.data(),
ptr_C1,
params.problem_size_1.mn(),
thread_idx,
threadblock_offset
@ -401,21 +756,21 @@ struct B2bGemm {
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D1(
params.params_D1,
params.ref_D1.data(),
ptr_D1,
params.problem_size_1.mn(),
thread_idx,
threadblock_offset
);
Epilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator construction
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
if (threadblock_tile_offset.k()) {
iterator_C1 = iterator_D1;
@ -427,14 +782,14 @@ struct B2bGemm {
}
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
//
// Release the semaphore
//
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
@ -457,4 +812,3 @@ struct B2bGemm {
} // namespace kernel
} // namespace gemm
} // namespace cutlass

View File

@ -0,0 +1,157 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Scheduler for grouped B2b GEMMs
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ThreadblockShape,
GroupScheduleMode GroupScheduleMode_,
int PrefetchTileCount,
int ThreadCount,
bool Transposed = false>
struct B2bGemmGroupedProblemVisitor : public GroupedProblemVisitor<
detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>,
ThreadblockShape,
GroupScheduleMode_,
PrefetchTileCount,
ThreadCount> {
using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>;
using Base = GroupedProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
using BaseParams = typename Base::Params;
using SharedStorage = typename Base::SharedStorage;
static bool const kTransposed = Transposed;
cutlass::gemm::GemmCoord const *problem_sizes0;
cutlass::gemm::GemmCoord const *problem_sizes1;
struct Params {
cutlass::gemm::GemmCoord const *problem_sizes0;
cutlass::gemm::GemmCoord const *problem_sizes1;
int32_t problem_count;
void const *workspace;
int32_t tile_count;
//
// Methods
//
/// Ctor
CUTLASS_HOST_DEVICE
Params(): problem_sizes0(nullptr), problem_sizes1(nullptr),
problem_count(0), workspace(nullptr), tile_count(0) { }
/// Ctor
CUTLASS_HOST_DEVICE
Params(
cutlass::gemm::GemmCoord const *problem_sizes0,
cutlass::gemm::GemmCoord const *problem_sizes1,
int32_t problem_count,
void const *workspace = nullptr,
int32_t tile_count = 0
):
problem_sizes0(problem_sizes0),
problem_sizes1(problem_sizes1),
problem_count(problem_count),
workspace(workspace),
tile_count(tile_count)
{}
/// Convert the B2b-GEMM-specific parameters to those used by the base class
CUTLASS_HOST_DEVICE
BaseParams to_base() const {
return BaseParams(// Set problem_sizes as problem_sizes0 because these determine
// shape of the grid used in the non-grouped B2b GEMM
problem_sizes0,
problem_count,
workspace,
tile_count);
}
};
//
// Methods
//
CUTLASS_DEVICE
B2bGemmGroupedProblemVisitor(
Params const &params_,
SharedStorage &shared_storage_,
int32_t block_idx
): Base (
params_.to_base(),
shared_storage_, block_idx),
problem_sizes0(params_.problem_sizes0),
problem_sizes1(params_.problem_sizes1)
{}
/// Returns the problem size 0 for the current problem
CUTLASS_HOST_DEVICE
cutlass::gemm::GemmCoord problem_size0() const {
GemmCoord problem = problem_sizes0[this->problem_idx];
ProblemSizeHelper::possibly_transpose_problem(problem);
return problem;
}
/// Returns the problem size 1 for the current problem
CUTLASS_HOST_DEVICE
cutlass::gemm::GemmCoord problem_size1() const {
GemmCoord problem = problem_sizes1[this->problem_idx];
ProblemSizeHelper::possibly_transpose_problem(problem);
return problem;
}
};
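/// Example instantiation (values are illustrative; they mirror the defaults used by
/// ProblemVisitorOrDefault in kernel/b2b_gemm.h):
///   using Visitor = B2bGemmGroupedProblemVisitor<
///       cutlass::gemm::GemmShape<64, 64, 32>,   // threadblock shape of the first GEMM
///       GroupScheduleMode::kDeviceOnly,
///       128, 128,                               // prefetch tile count, thread count
///       false>;                                 // not transposed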
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -30,10 +30,10 @@
**************************************************************************************************/
/*! \file
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
specializations here choose 'device::GemmTransposed' to implement this functionality.
@ -63,7 +63,9 @@
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "kernel/b2b_gemm.h"
#include "kernel/grouped.h"
#include "threadblock/default_b2b_mma.h"
#include "threadblock/grouped_threadblock_swizzle.h"
////////////////////////////////////////////////////////////////////////////////
@ -73,6 +75,9 @@ namespace kernel {
////////////////////////////////////////////////////////////////////////////////
template <typename T>
using IsGroupedSwizzle = cutlass::gemm::threadblock::detail::IsGroupedSwizzle<T>;
template <
/// Element type for A matrix operand
typename ElementA_,
@ -114,12 +119,12 @@ template <
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Stage accumulator in shared memory
bool SmemAccumulator = false
bool SmemAccumulator = false,
/// Whether or not the operation is grouped
typename Enable = void
>
struct DefaultB2bGemm;
@ -161,17 +166,77 @@ template <
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
WarpShape0, WarpShape1, InstructionShape,
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial,
Operator> {
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
Operator, false, typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
/// Define the epilogue
using Epilogue =
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
EpilogueOutputOp1::kCount>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
/// Partial specialization for Ampere Architecture with grouped operation
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape0,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape1,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape0,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape1,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp0,
/// Epilogue output operator
typename EpilogueOutputOp1,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
WarpShape0, WarpShape1, InstructionShape,
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
Operator, false, typename platform::enable_if<IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
@ -188,7 +253,9 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
EpilogueOutputOp1::kCount>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using UnderlyingB2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
using B2bGemmKernel = kernel::GroupedKernel<UnderlyingB2bGemmKernel>;
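// Note: the grouped path keeps the non-grouped B2bGemm kernel unchanged and wraps it in
// kernel::GroupedKernel, which is expected to visit each problem in the group via the
// problem visitor and invoke the underlying kernel per tile.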
};
@ -228,8 +295,6 @@ template <
typename EpilogueOutputOp1,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator
>
@ -249,8 +314,9 @@ struct DefaultB2bGemm<
EpilogueOutputOp1,
ThreadblockSwizzle,
2,
SplitKSerial,
Operator
Operator,
false,
typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type
> {
/// Define the threadblock-scoped matrix multiply-accumulate
@ -274,7 +340,7 @@ struct DefaultB2bGemm<
Operator,
EpilogueOutputOp0
>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
/// Define the epilogue
@ -287,7 +353,7 @@ struct DefaultB2bGemm<
>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
@ -323,20 +389,17 @@ template <
int Stages,
/// Number of Interleaved k
int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, Stages,
SplitKSerial, Operator> {
Operator, false, typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@ -360,7 +423,7 @@ struct DefaultB2bGemm<
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////
@ -396,19 +459,17 @@ template <
typename ThreadblockSwizzle,
/// Number of Interleaved k
int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
kAlignmentA, ElementB,
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
int32_t, arch::OpClassTensorOp, arch::Sm75,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, 2, SplitKSerial, Operator> {
ThreadblockSwizzle, 2, Operator, false,
typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@ -418,7 +479,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC,
arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1,
WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp0, true>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
@ -430,7 +491,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////

View File

@ -30,10 +30,10 @@
**************************************************************************************************/
/*! \file
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
specializations here choose 'device::GemmTransposed' to implement this functionality.
@ -112,22 +112,19 @@ template <
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
WarpShape0, WarpShape1, InstructionShape,
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial,
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
Operator, true> {
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, Stages, Operator, EpilogueOutputOp0, false, true>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
@ -139,10 +136,9 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
EpilogueOutputOp1::kCount>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Turing Architecture
@ -179,8 +175,6 @@ template <
typename EpilogueOutputOp1,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator
>
@ -200,7 +194,6 @@ struct DefaultB2bGemm<
EpilogueOutputOp1,
ThreadblockSwizzle,
2,
SplitKSerial,
Operator,
true
> {
@ -228,7 +221,7 @@ struct DefaultB2bGemm<
false,
true
>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
/// Define the epilogue
@ -241,7 +234,7 @@ struct DefaultB2bGemm<
>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
@ -277,20 +270,17 @@ template <
int Stages,
/// Number of Interleaved k
int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, Stages,
SplitKSerial, Operator, true> {
Operator, true> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@ -314,7 +304,7 @@ struct DefaultB2bGemm<
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////
@ -350,19 +340,16 @@ template <
typename ThreadblockSwizzle,
/// Number of Interleaved k
int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
kAlignmentA, ElementB,
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
int32_t, arch::OpClassTensorOp, arch::Sm75,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, 2, SplitKSerial, Operator, true> {
ThreadblockSwizzle, 2, Operator, true> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@ -371,9 +358,9 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, 2, Operator, EpilogueOutputOp0, true, true>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
@ -385,7 +372,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,168 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief High-level interface for running a grouped version of a CUTLASS kernel
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// High-level interface for running a grouped version of a CUTLASS kernel
template <
typename BaseKernel_ ///! Kernel-scoped matrix multiply-accumulate
>
struct GroupedKernel {
public:
using BaseKernel = BaseKernel_;
using Epilogue = typename BaseKernel::Epilogue;
/// Types that need to be exported to work properly with device::BaseGrouped
using ElementA = typename BaseKernel::ElementA;
using LayoutA = typename BaseKernel::LayoutA;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
static ComplexTransform const kTransformA = BaseKernel::kTransformA;
static int const kAlignmentA = BaseKernel::kAlignmentA;
using ElementB = typename BaseKernel::ElementB;
using LayoutB = typename BaseKernel::LayoutB;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
static ComplexTransform const kTransformB = BaseKernel::kTransformB;
static int const kAlignmentB = BaseKernel::kAlignmentB;
using ElementC = typename BaseKernel::ElementC;
using LayoutC = typename BaseKernel::LayoutC;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
static int const kAlignmentC = BaseKernel::kAlignmentC;
using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC;
using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp;
using ThreadblockSwizzle = typename BaseKernel::ThreadblockSwizzle;
using Operator = typename BaseKernel::Operator;
using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator;
using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
using MathOperator = typename WarpMmaOperator::MathOperator;
using OperatorClass = typename WarpMmaOperator::OperatorClass;
using ArchTag = typename WarpMmaOperator::ArchTag;
using ThreadblockShape = typename BaseKernel::Mma::Shape;
using WarpShape = typename BaseKernel::WarpShape;
using InstructionShape = typename BaseKernel::InstructionShape;
static int const kStages = BaseKernel::Mma::kStages;
using Mma = typename BaseKernel::Mma;
using Arguments = typename BaseKernel::GroupedArguments;
using Params = typename BaseKernel::GroupedParams;
using ProblemVisitor = typename ThreadblockSwizzle::ProblemVisitor;
static int const kThreadCount = BaseKernel::kThreadCount;
/// Shared memory storage structure
struct SharedStorage {
typename BaseKernel::SharedStorage kernel;
// ProblemVisitor shared storage can't be overlapped with others
typename ProblemVisitor::SharedStorage problem_visitor;
};
public:
//
// Methods
//
CUTLASS_DEVICE
GroupedKernel() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) {
return Status::kSuccess;
}
static Status can_implement(Arguments const &args) {
return Status::kSuccess;
}
/// Executes a kernel-level GEMM in a loop
CUTLASS_DEVICE
void operator()(Params &params, SharedStorage &shared_storage) {
ThreadblockSwizzle swizzle(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
if (ProblemVisitor::kTransposed) {
params.transpose();
}
BaseKernel mma;
// Outer 'persistent' loop to iterate over tiles
while (swizzle.problem_visitor.next_tile()) {
typename BaseKernel::Params mma_params = params.to_single_params(swizzle.problem_visitor);
mma.run_with_swizzle(mma_params, shared_storage.kernel, swizzle);
// Next tile
swizzle.problem_visitor.advance(gridDim.x);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
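For orientation, the following is a minimal launch sketch for the GroupedKernel adapter above; it is not code from this change. The launch_grouped helper is hypothetical, the Params object is assumed to have been assembled on the host already, and cutlass::Kernel is CUTLASS's generic device-kernel entry point (include/cutlass/device_kernel.h). The grid is sized by a persistent threadblock count, matching the outer problem-visitor loop in operator() above.
#include "cutlass/device_kernel.h"   // cutlass::Kernel<Operator>
// Hypothetical helper (illustration only): wrap an underlying kernel in
// kernel::GroupedKernel and launch it with persistent threadblocks.
template <typename UnderlyingKernel>
cudaError_t launch_grouped(
    typename cutlass::gemm::kernel::GroupedKernel<UnderlyingKernel>::Params params,
    int threadblock_count,
    cudaStream_t stream = nullptr) {
  using GroupedK = cutlass::gemm::kernel::GroupedKernel<UnderlyingKernel>;
  dim3 grid(threadblock_count, 1, 1);                 // persistent threadblocks
  dim3 block(GroupedK::kThreadCount, 1, 1);
  int smem_size = int(sizeof(typename GroupedK::SharedStorage));
  if (smem_size >= (48 << 10)) {
    // Large tile configurations must opt in to more dynamic shared memory.
    cudaError_t result = cudaFuncSetAttribute(
        cutlass::Kernel<GroupedK>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        smem_size);
    if (result != cudaSuccess) {
      return result;
    }
  }
  cutlass::Kernel<GroupedK><<<grid, block, smem_size, stream>>>(params);
  return cudaGetLastError();
}
In practice a device-level wrapper (the comment above mentions device::BaseGrouped) performs this wiring; the sketch only shows which exported members (kThreadCount, SharedStorage, Params) the launch path relies on.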

View File

@ -69,7 +69,7 @@ __global__ void TensorScaleBiasGemm(
TensorRefScalar tensor_scale, ///< scale tensor
TensorRefScalar tensor_bias ///< bias tensor
) {
ConvertOp convert_op;
MatrixCoord output_coord(
@ -89,7 +89,7 @@ __global__ void TensorScaleBiasGemm(
ScalarType bias = ScalarType(0);
if(tensor_bias.good())
bias = tensor_bias.at({0, coord.column()});
tensor_out.at(coord) = convert_op(
@ -99,6 +99,70 @@ __global__ void TensorScaleBiasGemm(
}
}
template <
typename TensorRefIn, ///< Input TensorRef Type
typename TensorRefOut, ///< Output TensorRef Type
typename ScalarType, ///< alpha Type
typename TensorRefScalar, ///< Scale/Bias TensorRef Type
typename ConvertOp = NumericConverter<typename TensorRefOut::Element, ScalarType>,
int kMblock = 4,
int kNblock = 4
>
__global__ void TensorScaleBiasGemmBatched(
gemm::GemmCoord problem_size,
TensorRefIn tensor_in, ///< input tensor
TensorRefOut tensor_out, ///< output tensor
ScalarType alpha, ///< alpha
TensorRefScalar tensor_scale, ///< scale tensor
TensorRefScalar tensor_bias, ///< bias tensor
int batch_count = 1,
int64_t batch_stride_tensor_in = 0,
int64_t batch_stride_tensor_out = 0,
int64_t batch_stride_tensor_scale = 0,
int64_t batch_stride_tensor_bias = 0
) {
ConvertOp convert_op;
int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
int batch_idx = blockIdx.z;
tensor_in.add_pointer_offset(batch_idx * batch_stride_tensor_in);
tensor_out.add_pointer_offset(batch_idx * batch_stride_tensor_out);
tensor_scale.add_pointer_offset(batch_idx * batch_stride_tensor_scale);
tensor_bias.add_pointer_offset(batch_idx * batch_stride_tensor_bias);
for (; batch_idx < batch_count; batch_idx += gridDim.z) {
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < kNblock; j++) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kMblock; i++) {
int row = row_block + i;
int col = col_block + j;
MatrixCoord coord = MatrixCoord(row, col);
if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) {
ScalarType scale = alpha;
if(tensor_scale.good())
scale = tensor_scale.at({0, coord.column()});
ScalarType bias = ScalarType(0);
if(tensor_bias.good())
bias = tensor_bias.at({0, coord.column()});
tensor_out.at(coord) = convert_op(
scale * ScalarType(tensor_in.at(coord)) + bias);
}
}
}
tensor_in.add_pointer_offset(batch_stride_tensor_in * gridDim.z);
tensor_out.add_pointer_offset(batch_stride_tensor_out * gridDim.z);
tensor_scale.add_pointer_offset(batch_stride_tensor_scale * gridDim.z);
tensor_bias.add_pointer_offset(batch_stride_tensor_bias * gridDim.z);
}
}
template <
typename TensorRefIn, ///< Input TensorRef Type
typename TensorRefOut, ///< Output TensorRef Type
@ -118,7 +182,7 @@ __global__ void TensorScaleBiasConv2d(
TensorRefScalar tensor_scale, ///< scale tensor
TensorRefScalar tensor_bias ///< bias tensor
) {
ConvertOp convert_op;
int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
@ -137,7 +201,7 @@ __global__ void TensorScaleBiasConv2d(
int64_t npq = npq_start + m;
thread_n[m] = int(npq / PQ);
int64_t residual = npq % PQ;
thread_p[m] = int(residual / problem_size.Q);
thread_q[m] = int(residual % problem_size.Q);
@ -155,17 +219,17 @@ __global__ void TensorScaleBiasConv2d(
ScalarType scale = alpha;
if(tensor_scale.good())
scale = tensor_scale.at({0, thread_k});
ScalarType bias = ScalarType(0);
if(tensor_bias.good())
bias = tensor_bias.at({0, thread_k});
tensor_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
scale * ScalarType(
tensor_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k})
) + bias);
}
}
}
}
}
@ -217,6 +281,62 @@ void TensorScaleBiasGemm(
);
}
/// Apply scale and bias on a tensor
template <
typename ElementIn, ///< Input Type
typename ElementOut, ///< Output Type
typename Layout, ///< Layout of input/output tensor
typename ScalarType, ///< alpha Type
typename LayoutScaleBias, ///< Layout of scale and bias
typename ConvertOp = NumericConverter<ElementOut, ScalarType>
>
void TensorScaleBiasGemmBatched(
gemm::GemmCoord problem_size,
TensorRef<ElementIn, Layout> tensor_in, ///< input tensor
TensorRef<ElementOut, Layout> tensor_out, ///< output tensor
ScalarType alpha, ///< alpha
TensorRef<ScalarType, LayoutScaleBias> tensor_scale, ///< scale tensor
TensorRef<ScalarType, LayoutScaleBias> tensor_bias, ///< bias tensor
int batch_count = 1,
int64_t batch_stride_tensor_in = 0,
int64_t batch_stride_tensor_out = 0,
int64_t batch_stride_tensor_scale = 0,
int64_t batch_stride_tensor_bias = 0
) {
int const kMblock = 4;
int const kNblock = 4;
dim3 block(16, 8);
dim3 grid(
(problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
(problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
batch_count % std::numeric_limits<uint16_t>::max()
);
kernel::TensorScaleBiasGemmBatched<
TensorRef<ElementIn, Layout>,
TensorRef<ElementOut, Layout>,
ScalarType,
TensorRef<ScalarType, LayoutScaleBias>,
ConvertOp,
kMblock,
kNblock
><<< grid, block >>> (
problem_size,
tensor_in,
tensor_out,
alpha,
tensor_scale,
tensor_bias,
batch_count,
batch_stride_tensor_in,
batch_stride_tensor_out,
batch_stride_tensor_scale,
batch_stride_tensor_bias
);
}
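Below is a minimal usage sketch for the batched scale/bias helper defined above, assuming this header and cutlass/util/host_tensor.h are included and namespace qualifiers for the helper are omitted. All names, extents, and strides are illustrative rather than taken from this change; batches are stacked along the row dimension, and the per-batch scale and bias vectors are rows of a (batch_count x n) tensor, so a batch stride of n steps from one vector to the next.
#include "cutlass/util/host_tensor.h"   // cutlass::HostTensor
void example_scale_bias_batched() {
  int m = 128, n = 256, k = 64, batch_count = 4;
  cutlass::gemm::GemmCoord problem_size(m, n, k);
  // Inputs/outputs for all batches, stacked along the row dimension.
  cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor> tensor_in({m * batch_count, n});
  cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor> tensor_out({m * batch_count, n});
  // One scale vector and one bias vector of length n per batch.
  cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor_scale({batch_count, n});
  cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor_bias({batch_count, n});
  tensor_in.sync_device();
  tensor_scale.sync_device();
  tensor_bias.sync_device();
  // Computes out[b] = scale[b] * in[b] + bias[b] element-wise for each batch b.
  TensorScaleBiasGemmBatched<
      cutlass::half_t, cutlass::half_t, cutlass::layout::RowMajor,
      float, cutlass::layout::RowMajor>(
    problem_size,
    tensor_in.device_ref(),
    tensor_out.device_ref(),
    /*alpha=*/1.0f,
    tensor_scale.device_ref(),
    tensor_bias.device_ref(),
    batch_count,
    /*batch_stride_tensor_in=*/int64_t(m) * n,
    /*batch_stride_tensor_out=*/int64_t(m) * n,
    /*batch_stride_tensor_scale=*/n,
    /*batch_stride_tensor_bias=*/n);
  tensor_out.sync_host();
}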
/// Apply scale and bias on a tensor
template <
typename ElementIn, ///< Input Type

View File

@ -119,8 +119,10 @@ public:
using Shape0 = Shape0_;
///< Iterates over tiles of A operand in global memory
using IteratorA0 = IteratorA0_;
using IteratorA = IteratorA0;
///< Iterates over tiles of B operand in global memory
using IteratorB0 = IteratorB0_;
using IteratorB = IteratorB0;
///< Policy describing tuning details
using Policy0 = Policy0_;
@ -139,6 +141,10 @@ public:
using IteratorB1 = IteratorB1_;
///< Policy describing tuning details
using Policy1 = Policy1_;
///< Export Policy0 as the threadblock-level Mma's policy
using Policy = Policy0;
using Shape = Shape0;
using SmemIteratorB1 = SmemIteratorB1_;
@ -188,6 +194,10 @@ public:
/// Complex transform on B operand
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
/// Complex transform exports needed by higher-level kernels
static ComplexTransform const kTransformA = kTransformA0;
static ComplexTransform const kTransformB = kTransformB0;
/// Internal structure exposed for introspection.
struct Detail {
@ -641,6 +651,11 @@ public:
}
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
// 2nd Gemm
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
@ -871,7 +886,10 @@ public:
}
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
};

View File

@ -121,8 +121,10 @@ public:
using Shape0 = Shape0_;
///< Iterates over tiles of A operand in global memory
using IteratorA0 = IteratorA0_;
using IteratorA = IteratorA0;
///< Iterates over tiles of B operand in global memory
using IteratorB0 = IteratorB0_;
using IteratorB = IteratorB0;
///< Iterates over tiles of the scale and bias vectors in global memory
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_;
///< Policy describing tuning details
@ -141,6 +143,10 @@ public:
///< Policy describing tuning details
using Policy1 = Policy1_;
///< Export Policy0 as the threadblock-level Mma's policy
using Policy = Policy0;
using Shape = Shape0;
using SmemIteratorB1 = SmemIteratorB1_;
using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory
@ -194,6 +200,10 @@ public:
/// Complex transform on B operand
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
/// Complex transform exports needed by higher-level kernels
static ComplexTransform const kTransformA = kTransformA0;
static ComplexTransform const kTransformB = kTransformB0;
/// Internal structure exposed for introspection.
struct Detail {
@ -664,6 +674,11 @@ public:
}
// Insert fence and wait for all outstanding cp.async operations to commit.
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
/// Epilogue for the first Implicit Gemm
Epilogue0 epilogue0;
@ -855,7 +870,10 @@ public:
}
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
};

View File

@ -126,7 +126,9 @@ public:
using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory
using IteratorA = IteratorA0;
using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory
using IteratorB = IteratorB0;
using Policy0 = Policy0_; ///< Policy describing tuning details
using SmemIteratorA0 = SmemIteratorA0_;
@ -139,6 +141,8 @@ public:
FragmentIteratorA1ScaleBias_; ///< WarpIterator to load Scale or Bias vector from the threadblock fragment
using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory
using Policy1 = Policy1_; ///< Policy describing tuning details
using Policy = Policy1; ///< Export Policy1 as the threadblock-level Mma's policy
using Shape = Shape1;
using SmemIteratorB1 = SmemIteratorB1_;
@ -195,6 +199,10 @@ public:
/// Complex transform on B1 operand
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
/// Complex transform exports needed by higher-level kernels
static ComplexTransform const kTransformA = kTransformA0;
static ComplexTransform const kTransformB = kTransformB0;
/// statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");

View File

@ -128,7 +128,9 @@ public:
using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory
using IteratorA = IteratorA0;
using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory
using IteratorB = IteratorB0;
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory
using Policy0 = Policy0_; ///< Policy0 describing tuning details
@ -141,6 +143,8 @@ public:
using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory
using Policy1 = Policy1_; ///< Policy1 describing tuning details
using Policy = Policy1; ///< Export Policy1 as the threadblock-level Mma's policy
using Shape = Shape1;
using SmemIteratorB1 = SmemIteratorB1_;
using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory
@ -192,6 +196,10 @@ public:
/// Complex transform on B1 operand
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
/// Complex transform exports needed by higher-level kernels
static ComplexTransform const kTransformA = kTransformA0;
static ComplexTransform const kTransformB = kTransformB0;
/// statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");

View File

@ -0,0 +1,125 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements several threadblock-swizzling functions for grouped kernels
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
#include "kernel/b2b_gemm_grouped_problem_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
struct GroupedThreadblockSwizzleBase {};
/// Helper for determining if a swizzling function is specialized for grouped operation
template <typename ThreadblockSwizzle>
struct IsGroupedSwizzle {
static bool const value = cutlass::platform::is_base_of<GroupedThreadblockSwizzleBase, ThreadblockSwizzle>::value;
};
} // namespace detail
/// Swizzling function for grouped kernels
template <typename ProblemVisitor_>
struct GroupedThreadblockSwizzle : detail::GroupedThreadblockSwizzleBase {
using ProblemVisitor = ProblemVisitor_;
ProblemVisitor problem_visitor;
CUTLASS_HOST_DEVICE
GroupedThreadblockSwizzle(typename ProblemVisitor::Params& params,
typename ProblemVisitor::SharedStorage& shared_storage,
int block_idx) : problem_visitor(params, shared_storage, block_idx) {}
/// Obtains the threadblock offset (in units of threadblock-scoped tiles)
CUTLASS_DEVICE
GemmCoord get_tile_offset(int /*log_tile*/) const {
GemmCoord problem_size = problem_visitor.problem_size();
int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
return GemmCoord(int(threadblock_idx / grid_shape.n()),
int(threadblock_idx % grid_shape.n()),
0);
}
/// Dummy method to satisfy API for threadblock swizzling functions
CUTLASS_HOST_DEVICE
static int get_log_tile(GemmCoord /*tiled_shape*/) {
return 0;
}
};
template <
typename ThreadblockShape,
typename LayoutC,
cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_ = cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
int PrefetchTileCount = 128,
int ThreadCount = PrefetchTileCount>
struct B2bGemmGroupedThreadblockSwizzle : GroupedThreadblockSwizzle<
cutlass::gemm::kernel::B2bGemmGroupedProblemVisitor<
ThreadblockShape,
GroupScheduleMode_,
PrefetchTileCount,
ThreadCount,
platform::is_same<LayoutC, cutlass::layout::ColumnMajor>::value
>
> {
using Base = GroupedThreadblockSwizzle<cutlass::gemm::kernel::B2bGemmGroupedProblemVisitor<
ThreadblockShape,
GroupScheduleMode_,
PrefetchTileCount,
ThreadCount,
platform::is_same<LayoutC, cutlass::layout::ColumnMajor>::value>>;
CUTLASS_HOST_DEVICE
B2bGemmGroupedThreadblockSwizzle(typename Base::ProblemVisitor::Params& params,
typename Base::ProblemVisitor::SharedStorage& shared_storage,
int block_idx) : Base(params, shared_storage, block_idx) {}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
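To make the dispatch mechanism concrete: the DefaultB2bGemm specializations shown earlier are guarded by enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>, so whether a kernel takes the grouped path is decided entirely by this trait. A small compile-time sketch, assuming this header and the standard identity swizzle header are included (the tile shape and layout are arbitrary examples):
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"   // GemmIdentityThreadblockSwizzle
// Arbitrary example instantiation of the grouped swizzle defined above.
using GroupedSwizzle = cutlass::gemm::threadblock::B2bGemmGroupedThreadblockSwizzle<
    cutlass::gemm::GemmShape<128, 128, 32>,   // threadblock tile
    cutlass::layout::RowMajor>;               // layout of C
using IdentitySwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
// Grouped swizzles derive from GroupedThreadblockSwizzleBase and are detected;
// the ordinary identity swizzle is not.
static_assert(cutlass::gemm::threadblock::detail::IsGroupedSwizzle<GroupedSwizzle>::value,
              "expected the grouped swizzle to be detected");
static_assert(!cutlass::gemm::threadblock::detail::IsGroupedSwizzle<IdentitySwizzle>::value,
              "expected the identity swizzle not to be detected");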

View File

@ -31,83 +31,181 @@
/**
This example shows how to run convolution kernels using functions and data structures
provided by CUTLASS using tensor cores; which we run on a NVIDIA Ampere GPU.
This example shows how to run CUTLASS's convolution kernels
based on the Implicit GEMM algorithm, which use the Tensor Cores
on an NVIDIA Ampere GPU.
Writing a single high performance convolution kernel is hard but do-able. Whereas writing
high performance kernels at scale which works for multiple problem sizes with good abstractions is
really hard. CUTLASS solves this problem by providing simplified abstractions to compose
multiple sections of implicit gemm kernel. When used properly, the kernels can hit peak performance
of GPU easily.
Writing a single high-performance convolution kernel is hard enough,
let alone writing kernels that perform well for multiple problem sizes
and use good software abstractions.
CUTLASS provides simplified abstractions
to compose multiple sections of a convolution kernel.
When used properly, the kernels can reach peak GPU performance.
CUTLASS divides a kernel into hierarchical composable sections. Which means, at each thread, warp
and thread-block level, they compute on their own tile-size with higher level of tile sizes being
composed from lower level ones. Multiple thread-tiles (tile size each thread computes) can be used
to form warp-tiles (tile size each warp computes) and multiple warp tiles can be used to compute
threadblock-tile (tile size computed by a threadblock).
CUTLASS divides a kernel into hierarchical composable sections
for each level of the GPU hardware hierarchy:
thread, warp, and threadblock.
Each section computes on its own tile shape,
with each higher level's tile shape
being composed from lower-level tile shapes.
Multiple thread tiles (the tile shape each thread computes)
can be used to form warp tiles (the tile shape each warp computes),
and multiple warp tiles can be used to compute threadblock tiles
(the tile shape computed by a threadblock).
In thie example, we split variable initialization into
1. Setting up data properties : describes how tensors are laid out in the memory and how the kernel
can view them (logical to physical mapping)
2. Setting up computation properties : describes how the above set tensors will be used to compute
output of convolution.
In this example, we split variable initialization into two parts.
First, we setup the data types of the input tensor A, weights' tensor B and output tensor C along
with alpha, beta as the equation for convolution is C = alpha * Conv2dFprop(A, B) + beta * C. In CUTLASS,
the kernels first compute Conv2dFprop(A, B) and leave the rest of the computation to end of the kernel as
alpha * X + beta * C is a simple element-wise operation on X (Conv2dFprop(A, B)) and C. We call this as
epilogue of kernel. Hence, we setup data types for alpha and beta to be equal to
ElementComputeEpilogue = float. We use the data type for elements in input tensor A and B as
cutlass::half_t. We convey this to CUTLASS kernel by initializing template variables ElementAccumulator (float),
ElementComputeEpilogue (float), ElementInputA (cutlass::half_t), ElementInputB (cutlass::half_t),
ElementOutput (float). Communicating just the data type is not enough. As the data is laid out
linearly in memory, we have to convey the layout of tensors. We do that by initializing template
variables LayoutInputA, LayoutInputB and LayoutOutput to TensorNHWC cutlass variable. Next, we setup
rules to comptue alpha * X + beta * C which is called epilogue of the kernel. We initialize template
variable EpilogueOp, which takes the data type of output ElementOutput (float), the number of
elements per vector memory access (8), data type of accumulator (float) and data type of
computation of linear combination (alpha * X + beta * C).
1. Setting up data properties: describes how tensors are laid out in the memory
and how the kernel can view them (logical to physical mapping)
Now that we setup the properties of data, we have to setup properties of computation.
2. Setting up computation properties: describes how the above tensors
will be used to compute the output of convolution
Second, we create template variables of tile sizes for thread-block, warp and mma-op to 128x128x64,
64x64x64, 16x8x16 (MxNxK) respectively. When passed to instantiate CUTLASS Implicit GEMM kernel, it
internally deduces the amount of threads needed per thread-block, amount of shared memory, storing
data in bank-conflict free manner, and ton of other variables required to compose, initialize and
launch a high performance Implicit GEMM kernel. This is the beauty of CUTLASS, it relieves developer
from understanding and coding complicated hardware optimizations which can easily go wrong.
We begin by setting up the data types
of all the input and output elements of a convolution.
A convolution computes
C = alpha * Conv2dFprop(A, B) + beta * C,
so we set up data types for the input tensor A,
weights tensor B, output tensor C,
and the scaling factors alpha and beta.
CUTLASS divides the convolution into two parts:
the "mainloop" that computes X = Conv2dFprop(A, B),
and the "epilogue" that computes C = alpha * X + beta * C.
The epilogue is an element-wise operation on X and C.
In this case, it is a linear combination,
but other epilogues are possible.
CUTLASS also supports multiple MMA pipelines in a threadblock. What are MMA pipelines? MMA pipelines
constitute the whole process of loading input data from global memory to shared memory, loading data
from shared memory to registers, doing matrix multiplication, store to global memory. The below flow
sequence shows a typical mma multistage pipeline.
(see include/cutlass/conv/threadblock/implicit_gemm_multistage.h)
In this example, we want
tensor in global memory --cp_async--> tile in shared memory --smem loads--> registers
--mma--> registers --global stores--> output to global memory
* the scaling factors alpha and beta to be float,
NVIDIA Ampere uses `cp_async` to build multistage software pipeline to better hide latencies.
* the elements of A and B to be cutlass::half_t
(a 16-bit floating-point type),
* the elements of C to be float, and
There are few more template variables initialized such as, which threadblock tile of output matrix
is done which threadblock launched on an SM, CUDA SM architecture of GPU you want to run on.
* intermediate sums to be accumulated in float.
These are all put together to create a template variable which describes CUTLASS Implicit GEMM
kernel using cutlass::conv::device::ImplicitGemm template.
We convey this to the CUTLASS kernel
by setting the following template parameters.
The next step is to initialize physical data, instantiate and initialize CUTLASS kernel and run it.
We use CUTLASS utilities to initialize, fill, compare tensors as they are simple and doesn't come
in the way of learning CUTLASS.
* alpha and beta: ElementComputeEpilogue = float
Once all the tensors are initialized and filled with data, create arguments tuple to launch CUTLASS
kernel which takes problem size (N = 1, H = 64, W = 64, C = 128), filter size (K = 64,
R = 3, S = 3, C = 128 ), padding, strides, dilation, tensors, alpha, beta and the
important one, split k-dimension factor. Along with that, we query CUTLASS if any scratch-space
memory required by the kernel we instantiated. If yes, we create it and pass it along with other
arguments created to initialize CUTLASS kernel then, the kernel is launched.
* Elements of input tensor A: ElementInputA = cutlass::half_t
In this example, we later on launch a reference convolution kernel (from CUTLASS utilities) to
compare if the output from CUTLASS kernel is same as the reference implicit GEMM kernel.
* Elements of input tensor B: ElementInputB = cutlass::half_t
* Elements of output tensor C: ElementOutput = float
* Accumulation type: ElementAccumulator = float
Next, we describe the layout of the input and output tensors.
We convey this to the CUTLASS kernel
by setting the following template parameters.
* Layout of input tensor A: LayoutInputA = TensorNHWC
* Layout of input tensor B: LayoutInputB = TensorNHWC
* Layout of output tensor C: LayoutOutput = TensorNHWC
After that, we set up rules to compute the epilogue.
The epilogue in this case is a simple linear combination
C = alpha * X + beta * C.
Thus, we set the kernel's template parameter EpilogueOp
to LinearCombination. LinearCombination itself
has template parameters:
* the element type of the output tensor (ElementOutput),
* the number of elements per vector memory access (8),
* the data type of the accumulator (ElementAccumulator),
* and the data type used to compute the linear combination
(ElementComputeEpilogue).
We then define the tile shapes
that each level of the computation uses.
We define these as types that encode the tile shapes
as compile-time integer values.
Each shape expresses the dimensions M x N x K.
Here, the letters refer to the dimensions
of a matrix-matrix multiply.
* ThreadblockShape defines the threadblock tile shape
as 128 x 128 x 64.
* WarpShape defines the warp tile shape as 64 x 64 x 64.
* InstructionShape defines the MMA
(matrix multiply-accumulate) operation shape
as 16 x 8 x 16.
These types become template arguments
of the kernel properties type
cutlass::conv::kernel::DefaultConv2dFprop.
The kernel uses these shapes to deduce
the number of threads needed per threadblock,
the required amount of shared memory,
the internal layouts needed to access
shared memory without bank conflicts,
and many other properties that the kernel needs
for good performance.
CUTLASS deduces all these properties automatically,
so that users don't have to.
DefaultConv2dFprop accepts other template parameters
that describe things like the target CUDA SM architecture.
CUTLASS also supports multiple MMA pipelines in a threadblock.
An MMA pipeline constitutes the whole process
of loading input data from global memory to shared memory,
loading data from shared memory to registers,
doing matrix multiplication,
and storing the result to global memory.
The below flow sequence shows a typical MMA multistage pipeline
(see include/cutlass/conv/threadblock/implicit_gemm_multistage.h).
tensor in global memory
--cp_async-->
tile in shared memory
--smem loads-->
registers
--mma-->
registers
--global stores-->
output to global memory
On NVIDIA Ampere, the kernel uses `cp_async`
to build a multistage software pipeline.
This helps it better hide latency.
At this point, we can define the actual CUTLASS kernel type
as the alias ImplicitGemm, a specialization of
cutlass::conv::device::ImplicitGemmConvolution.
The latter accepts the kernel properties type alias
Conv2dFpropKernel as its one template argument.
This example then sets up a test problem
and arguments to the kernel.
We use CUTLASS utilities to allocate
the input and output tensors
and fill them with sample input data.
We then create the kernel arguments
as an instance of ImplicitGemm::Arguments.
The arguments include
the problem size (N = 1, H = 64, W = 64, C = 128),
filter size (K = 64, R = 3, S = 3, C = 128),
padding, strides, dilation, tensors, alpha, beta,
and the split k-dimension factor.
We also query CUTLASS if the kernel we instantiated
requires any memory for scratch space.
If yes, we reserve scratch space and pass it along
with other arguments to initialize the CUTLASS kernel.
After launching the CUTLASS kernel, this example runs
a reference convolution kernel (from CUTLASS utilities)
to check correctness.
*/
#include <iostream>
@ -131,8 +229,8 @@ compare if the output from CUTLASS kernel is same as the reference implicit GEMM
#include "helper.h"
// The code section below describes datatype for input, output tensors and computation between
// elements
// Data types for input and output tensors
// and computation between elements
using ElementAccumulator = float; // Data type of accumulator
using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta)
using ElementInputA = cutlass::half_t; // Data type of elements in input tensor
@ -143,39 +241,40 @@ using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
// Whether to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = cutlass::arch::OpClassTensorOp;
// This code section describes CUDA SM architecture number
// SM architecture number
using SmArch = cutlass::arch::Sm80;
// This code section describes the tile size a thread block will compute
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; // Threadblock tile shape
// Threadblock tile shape
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
// This code section describes tile size a warp will compute
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; // Warp tile shape
// Warp tile shape
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
// This code section describes the size of MMA op
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape
// MMA (Tensor Core instruction, in this case) tile shape
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
// This code section describes how threadblocks are scheduled on GPU
// How the kernel schedules threadblocks
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
// Number of pipelines you want to use
// Number of pipeline stages to use
constexpr int NumStages = 3;
// This code section describe iterator algorithm selected is Analytic or Optimized
// Which iterator algorithm to use: Analytic or Optimized
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized;
// This code section describes the epilogue part of the kernel, we use default value
// The epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput, // Data type of output matrix.
128 / cutlass::sizeof_bits<ElementOutput>::value, // The number of elements per vectorized.
128 / cutlass::sizeof_bits<ElementOutput>::value, // The number of elements per vectorized
// memory access. This becomes the vector width of
// math instructions in the epilogue too.
ElementAccumulator, // Data type of accumulator
ElementComputeEpilogue>; // Data type for alpha/beta in linear combination
// Kernel properties type
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
ElementInputA, LayoutInputA,
ElementInputB, LayoutInputB,
@ -193,6 +292,7 @@ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
IteratorAlgorithm
>::Kernel;
// Type of the actual kernel
using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -230,7 +330,7 @@ struct Options {
beta(0),
benchmark(false) { }
// Verify the problem size is compatible with the CUTLASS Convolution implementation.
// Verify that the problem size is compatible with CUTLASS's convolution implementation
bool valid() {
//
@ -256,7 +356,7 @@ struct Options {
return true;
}
/// Updates input and filter sizes
/// Update input and filter sizes
void update(
cutlass::Tensor4DCoord input_size,
cutlass::Tensor4DCoord filter_size) {
@ -270,7 +370,7 @@ struct Options {
padding.c() = filter_size.w() / 2;
}
// Parses the command line
// Parse command-line arguments
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
@ -302,11 +402,11 @@ struct Options {
cmd.get_cmd_line_argument("k", filter_size.n());
cmd.get_cmd_line_argument("r", filter_size.h());
cmd.get_cmd_line_argument("s", filter_size.w());
filter_size.c() = input_size.c();
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("tag", tag);
@ -320,12 +420,12 @@ struct Options {
}
}
/// Prints the usage statement.
/// Print an explanation of the command-line arguments
std::ostream & print_usage(std::ostream &out) const {
out << "16_ampere_tensorop_conv2dfprop example\n\n"
<< " This example uses Ampere's Tensor Core operators on F16 data types to compute\n"
<< " forward convolution on tensors of layout NHWC.\n\n"
<< " This example uses Ampere's Tensor Core operators on F16 data types\n"
<< " to compute forward convolution on tensors of layout NHWC.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --n=<int> Input tensor extent N\n"
@ -350,7 +450,7 @@ struct Options {
return out;
}
/// Computes the output tensor size (NPQK)
cutlass::Tensor4DCoord output_size() const {
return cutlass::Tensor4DCoord(
@ -360,19 +460,20 @@ struct Options {
filter_size.n());
}
/// Compute performance in GFLOP/s
/// Compute performance in Gflop/s
///
/// Gflop/s stands for billions (10^9) of
/// floating-point operations per second (Gflop/s).
double gflops(double runtime_s) const {
// Number of multiply-adds = NPQK * CRS
int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c());
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
struct Result {
double runtime_ms;
double gflops;
@ -380,14 +481,14 @@ struct Result {
cutlass::Status reference_check;
cudaError_t error;
Result():
runtime_ms(0),
gflops(0),
status(cutlass::Status::kSuccess),
reference_check(cutlass::Status::kInvalid),
error(cudaSuccess) { }
static std::ostream & print_header(std::ostream &out, Options const &options) {
static std::ostream& print_header(std::ostream &out, Options const &options) {
if (!options.tag.empty()) {
out << "Name,";
@ -404,7 +505,7 @@ struct Result {
out << options.tag << ",";
}
out
<< "conv_" << idx << ","
<< options.input_size.n() << ","
<< options.input_size.h() << ","
@ -420,8 +521,6 @@ struct Result {
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Runs one benchmark
Result profile_convolution(Options const &options) {
@ -441,7 +540,7 @@ Result profile_convolution(Options const &options) {
// Initialize tensors
//
// Fill tensor A on host with uniform-distribution random data
// Fill tensor A on host with uniformly distributed random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
@ -449,7 +548,7 @@ Result profile_convolution(Options const &options) {
ElementInputA(-8),
0);
// Fill tensor B on host with uniform-distribution random data
// Fill tensor B on host with uniformly distributed random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
1,
@ -457,7 +556,7 @@ Result profile_convolution(Options const &options) {
ElementInputB(-8),
0);
// Fill tensor C on host with uniform-distribution random data
// Fill tensor C on host with uniformly distributed random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view(),
1,
@ -490,7 +589,7 @@ Result profile_convolution(Options const &options) {
int split_k_slices = 1;
// Construct Conv2dProblemSize with user defined output size
cutlass::conv::Conv2dProblemSize problem_size(
options.input_size,
options.filter_size,
options.padding,
@ -501,7 +600,7 @@ Result profile_convolution(Options const &options) {
split_k_slices
);
// Construct ImplicitGemm::Argument structure with conv2d
// problem size, data pointers, and epilogue values
typename ImplicitGemm::Arguments arguments{
problem_size,
@ -539,7 +638,7 @@ Result profile_convolution(Options const &options) {
//
// Optional reference check
//
if (options.reference_check) {
std::cout << "Verification on host...\n";
@ -552,8 +651,7 @@ Result profile_convolution(Options const &options) {
ElementOutput,
LayoutOutput,
ElementComputeEpilogue,
ElementAccumulator,
cutlass::NumericConverter<ElementOutput, ElementComputeEpilogue>
ElementAccumulator
>(
problem_size,
tensor_a.host_ref(),
@ -564,7 +662,7 @@ Result profile_convolution(Options const &options) {
options.beta
);
// Check if output from CUTLASS kernel and reference kernel are equal or not
// Check if CUTLASS kernel and reference kernel produced the same output
tensor_d.sync_host();
bool passed = cutlass::reference::host::TensorEquals(
@ -589,14 +687,14 @@ Result profile_convolution(Options const &options) {
std::stringstream ss;
ss << "16_ampere_workspace_conv2dfprop_"
<< options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
<< "_"
<< options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
<< ".dat";
std::ofstream output_workspace(ss.str());
output_workspace
<< "Input = \n" << tensor_a.host_view() << "\n\n"
<< "Filters = \n" << tensor_b.host_view() << "\n\n";
@ -616,7 +714,7 @@ Result profile_convolution(Options const &options) {
if (options.measure_performance) {
cudaEvent_t events[2];
for (auto & event : events) {
result.error = cudaEventCreate(&event);
if (result.error != cudaSuccess) {
@ -632,7 +730,7 @@ Result profile_convolution(Options const &options) {
return result;
}
// Launch a sequence of implicit GEMM operations on the device
// Launch a sequence of implicit GEMM operations on the device.
for (int iteration = 0; iteration < options.iterations; ++iteration) {
result.status = implicit_gemm_op();
CUTLASS_CHECK(result.status);
@ -652,7 +750,7 @@ Result profile_convolution(Options const &options) {
return result;
}
// Measure elapsed runtime
// Measure elapsed runtime.
float runtime_ms = 0;
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != cudaSuccess) {
@ -660,7 +758,7 @@ Result profile_convolution(Options const &options) {
return result;
}
// Print average runtime and GFLOPs.
// Print average run time and floating-point throughput (Gflop/s).
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
@ -673,8 +771,6 @@ Result profile_convolution(Options const &options) {
return result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
bool notSupported = false;
@ -701,7 +797,7 @@ int main(int argc, char const **args) {
}
Options options;
options.parse(argc, args);
if (options.help) {
@ -768,5 +864,3 @@ int main(int argc, char const **args) {
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
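As a recap of the flow described in this example's header comment, the core of profile_convolution() condenses to roughly the following sketch (error handling trimmed; problem_size, tensor_a through tensor_d, and options are the example's own variables, and get_workspace_size / can_implement / initialize are the standard device-level entry points of ImplicitGemmConvolution).
ImplicitGemm implicit_gemm_op;
typename ImplicitGemm::Arguments arguments{
  problem_size,                 // Conv2dProblemSize built from the Options above
  tensor_a.device_ref(),        // input activations
  tensor_b.device_ref(),        // filters
  tensor_c.device_ref(),        // source tensor scaled by beta
  tensor_d.device_ref(),        // destination tensor
  {options.alpha, options.beta} // epilogue parameters
};
// Query and reserve any scratch space the instantiated kernel needs.
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
CUTLASS_CHECK(implicit_gemm_op.can_implement(arguments));
CUTLASS_CHECK(implicit_gemm_op.initialize(arguments, workspace.get()));
CUTLASS_CHECK(implicit_gemm_op());  // D = alpha * Conv2dFprop(A, B) + beta * C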

View File

@ -470,8 +470,7 @@ Result profile_convolution(Options const &options) {
ElementOutput,
LayoutOutput,
ElementComputeEpilogue,
ElementAccumulator,
cutlass::NumericConverter<ElementOutput, ElementComputeEpilogue>
ElementAccumulator
>(
problem_size,
tensor_a.host_ref(),

View File

@ -30,7 +30,7 @@
**************************************************************************************************/
/**
The example demenstrates how to reduce one of the operands of the GEMM along the k-dimension when
The example demonstrates how to reduce one of the operands of the GEMM along the k-dimension when
computing GEMM. So the output also contains either a Mx1 or 1XN vector. It only works with Ampere
16x8x16 FP16/BF16 tensor cores, though it is not difficult to apply to other Turing/Ampere tensor
core instructions.

View File

@ -31,6 +31,7 @@
cutlass_example_add_executable(
24_gemm_grouped
gemm_grouped.cu
)

View File

@ -37,7 +37,7 @@
leading dimensions and problem sizes are stored in arrays in GMEM.
This differs from "Batched Array" GEMM because the size of each GEMM problem in the Grouped GEMM
concept may be distinct.
This benchmark program initializes a workspace with random problem sizes for a given number of
groups. Command line options enable overriding M, N, and/or K dimensions with uniform values to
@ -186,7 +186,7 @@ struct Options {
//
// Methods
//
Options():
help(false),
@ -216,7 +216,7 @@ struct Options {
cmd.get_cmd_line_argument("alignment", alignment, 8);
cmd.get_cmd_line_argument("groups", problem_count, 15);
cmd.get_cmd_line_argument("alpha", alpha, 1.0f);
cmd.get_cmd_line_argument("beta", beta, 0.0f);
cmd.get_cmd_line_argument("iterations", iterations, 20);
cmd.get_cmd_line_argument("streams", cuda_streams, 0);
cmd.get_cmd_line_argument("verbose", verbose, false);
@ -455,13 +455,13 @@ struct Options {
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of real-valued multiply-adds
int64_t fmas = int64_t();
for (auto const & problem : problem_sizes) {
fmas += problem.product();
}
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
@ -546,7 +546,7 @@ public:
template <typename Element>
void initialize_tensor(
Element *ptr,
size_t capacity,
cutlass::Distribution::Kind dist_kind,
uint32_t seed) {
@ -578,7 +578,7 @@ public:
cutlass::reference::device::BlockFillRandomUniform(
ptr, capacity, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::device::BlockFillRandomGaussian(
@ -589,7 +589,7 @@ public:
// Fill with increasing elements
cutlass::reference::device::BlockFillSequential(
ptr, capacity, Element(1), Element());
}
else {
// Fill with all 1s
@ -674,13 +674,13 @@ public:
ptr_A.reset(problem_count());
ptr_A.copy_from_host(ptr_A_host.data());
ptr_B.reset(problem_count());
ptr_B.copy_from_host(ptr_B_host.data());
ptr_C.reset(problem_count());
ptr_C.copy_from_host(ptr_C_host.data());
ptr_D.reset(problem_count());
ptr_D.copy_from_host(ptr_D_host.data());
@ -712,7 +712,7 @@ public:
MatrixCoord extent_A{problem.m(), problem.k()};
MatrixCoord extent_B{problem.k(), problem.n()};
MatrixCoord extent_C{problem.m(), problem.n()};
cutlass::TensorView<ElementA, LayoutA> view_A(block_A.get() + offset_A.at(i), layout_A, extent_A);
cutlass::TensorView<ElementB, LayoutB> view_B(block_B.get() + offset_B.at(i), layout_B, extent_B);
cutlass::TensorView<ElementC, LayoutC> view_C(block_C.get() + offset_C.at(i), layout_C, extent_C);
@ -724,18 +724,18 @@ public:
cutlass::reference::device::GemmComplex<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ElementCompute, ElementAccumulator
>(
problem,
options.alpha,
view_A,
Gemm::kTransformA,
view_B,
Gemm::kTransformB,
options.beta,
view_C,
view_Ref_device,
ElementAccumulator(0)
);
@ -781,8 +781,8 @@ public:
std::cout << "Conventionally executed as " << this->options.problem_bins.size() << " batched GEMMs:\n";
for (auto const & bin : this->options.problem_bins) {
std::cout << " [" << bin_idx << "]: "
<< bin.first.m() << "-by-" << bin.first.n() << "-by-" << bin.first.k()
<< ", batch count: " << bin.second.size() << "\n";
++bin_idx;
@ -832,7 +832,7 @@ public:
for (auto const & bin : this->options.problem_bins) {
int first_idx = bin.second.front();
bin_problem_sizes.push_back(this->options.problem_sizes.at(first_idx));
bin_count.push_back(int32_t(bin.second.size()));
@ -974,7 +974,7 @@ public:
std::cerr << "CUTLASS error on line " << __LINE__ << std::endl;
return result;
}
//
@ -1027,7 +1027,7 @@ public:
int last_stream_idx = 0;
for (int iter = 0; iter < this->options.iterations; ++iter) {
for (int bin_idx = 0; bin_idx < int32_t(bin_problem_sizes.size()); ++bin_idx) {
cutlass::gemm::GemmCoord const & problem = bin_problem_sizes[bin_idx];
@ -1098,7 +1098,7 @@ public:
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
//
// Wait for work to be completed
//
@ -1129,10 +1129,10 @@ public:
for (auto event : events) {
(void)cudaEventDestroy(event);
}
for (auto stream : cuda_streams) {
if (stream) {
(void)cudaStreamDestroy(stream);
}
}
@ -1203,8 +1203,8 @@ public:
int tiles = Gemm::problem_tile_count(problem);
total_tiles += tiles;
std::cout << " [" << idx << "]: "
<< problem.m() << "-by-" << problem.n() << "-by-" << problem.k()
<< " (" << tiles << " threadblock tiles)" << "\n";
++idx;
@ -1442,12 +1442,12 @@ int main(int argc, char const **args) {
}
if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) {
//
// This example requires an NVIDIA Ampere-architecture GPU.
//
std::cout
<< "CUTLASS's Grouped GEMM example requires a GPU of NVIDIA's Ampere Architecture or "
<< "later (compute capability 80 or greater).\n";
@ -1497,9 +1497,9 @@ int main(int argc, char const **args) {
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
@ -1519,8 +1519,8 @@ int main(int argc, char const **args) {
cutlass::ComplexTransform::kNone,
8,
ElementOutput, LayoutC,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>,
@ -1531,7 +1531,7 @@ int main(int argc, char const **args) {
// NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
// This parameter is passed in at present to match the APIs of other kernels. The parameter
// is unused within the kernel.
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
4>::GemmKernel;
using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;

View File

@ -181,7 +181,7 @@ struct Options {
<< " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_complex_gemm --m=1024 --n=512 \\\n"
<< "$ ./examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm --m=1024 --n=512 \\\n"
<< " --alpha=2 --beta=0.707 \n\n";
return out;

View File

@ -27,9 +27,9 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Both filenames are shorter to avoid MAX_PATH issues on Windows.
cutlass_example_add_executable(
29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm
29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm.cu
29_3xtf32_complex_gemm
29_3xtf32_complex_gemm.cu
)

View File

@ -34,7 +34,7 @@
matrix multiply kernel to verify its correctness.
The CUTLASS Syrk template is instantiated in the function CutlassSsyrkNN. This kernel computes
the symmetric rank-k update (SYRK) using double-precision doubleing-point arithmetic and assumes
the symmetric rank-k update (SYRK) using double-precision floating-point arithmetic and assumes
all matrices have column-major layout.
The threadblock tile size is chosen as 16x32x16 which offers good performance for large matrices.
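For reference, an illustrative host-side sketch of the lower-triangular, non-transposed rank-k
update (not the CUTLASS kernel itself; the example's exact fill mode and configuration may differ):

    // C(i,j) = alpha * sum_k A(i,k) * A(j,k) + beta * C(i,j), for j <= i, column-major storage
    for (int i = 0; i < N; ++i) {
      for (int j = 0; j <= i; ++j) {
        double acc = 0.0;
        for (int k = 0; k < K; ++k) {
          acc += A[i + k * lda] * A[j + k * lda];
        }
        C[i + j * ldc] = alpha * acc + beta * C[i + j * ldc];
      }
    }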

View File

@ -34,7 +34,7 @@
matrix multiply kernel to verify its correctness.
The CUTLASS Trmm template is instantiated in the function CutlassStrmmNN. This kernel computes
the triangular matrix product (TRMM) using double-precision doubleing-point arithmetic and assumes
the triangular matrix product (TRMM) using double-precision floating-point arithmetic and assumes
all matrices have column-major layout.
The threadblock tile size is chosen as 64x64x16 which offers good performance for large matrices.
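For reference, an illustrative host-side sketch of a left-sided, lower-triangular B = alpha * A * B
(not the CUTLASS kernel itself; the example's exact side/fill/diag configuration may differ):

    for (int j = 0; j < N; ++j) {            // columns of B (M x N), column-major
      for (int i = M - 1; i >= 0; --i) {     // bottom-up so B can be updated in place
        double acc = 0.0;
        for (int k = 0; k <= i; ++k) {       // A is M x M, lower triangular
          acc += A[i + k * lda] * B[k + j * ldb];
        }
        B[i + j * ldb] = alpha * acc;
      }
    }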

View File

@ -578,9 +578,21 @@ public:
int gemm_smem_size = int(sizeof(typename GemmKernel::SharedStorage));
cudaError_t result;
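// Dynamic shared memory requests above 48 KB require an explicit opt-in via
// cudaFuncSetAttribute(cudaFuncAttributeMaxDynamicSharedMemorySize) before launch;
// without it, launching a kernel whose SharedStorage exceeds 48 KB would fail.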
if (gemm_smem_size >= (48 << 10)) {
result = cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
gemm_smem_size);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
cutlass::Kernel<GemmKernel><<<gemm_grid, gemm_block, gemm_smem_size, stream>>>(params_.gemm);
cudaError_t result = cudaGetLastError();
result = cudaGetLastError();
if (result != cudaSuccess) {
return cutlass::Status::kErrorInternal;

View File

@ -316,7 +316,11 @@ int run(Options &options) {
// <- Fill tensor_b_indices on host with unique random integers
std::vector<int> to_fill(problem_size.n()) ; // vector with ints.
std::iota (std::begin(to_fill), std::end(to_fill), 0); // Fill with 0, 1, ...., problem_size.n()
std::random_shuffle(to_fill.begin(), to_fill.end());
{ // std::random_shuffle was deprecated in C++14 and removed in C++17
std::random_device make_seed;
std::mt19937 source_of_randomness(make_seed());
std::shuffle(to_fill.begin(), to_fill.end(), source_of_randomness);
}
memcpy(tensor_indices.host_data(), to_fill.data(), options.index_size * sizeof(int));
// Copy data from host to GPU

File diff suppressed because it is too large

View File

@ -0,0 +1,510 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines additional layout functions used in Permute GEMM example to simplify
computing reference permutations of 4/5D tensors when source data is column-major.
*/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include "assert.h"
#endif
#include "cutlass/cutlass.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/coord.h"
#include "cutlass/tensor_coord.h"
namespace cutlass {
namespace layout {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Mapping function for 4-D CWHN tensors.
class TensorCWHN {
public:
/// Logical rank of tensor
static int const kRank = 4;
/// Rank of stride vector
static int const kStrideRank = 3;
/// Index type used for coordinates
using Index = int32_t;
/// Long index type used for offsets
using LongIndex = int64_t;
/// Logical coordinate (n, h, w, c)
using TensorCoord = Tensor4DCoord;
/// Stride vector
using Stride = Coord<kStrideRank>;
private:
//
// Data members
//
/// Stride data member - [n, hn, whn]
Stride stride_;
public:
//
// Methods
//
/// Constructor
CUTLASS_HOST_DEVICE
TensorCWHN(Stride const &stride = Stride(0)): stride_(stride) { }
/// Constructor
CUTLASS_HOST_DEVICE
TensorCWHN(
typename Stride::Index stride_h, ///< number of elements between adjacent H coordinates
typename Stride::Index stride_w, ///< number of elements between adjacent W coordinates
typename Stride::Index stride_c ///< number of elements between adjacent C coordinates
):
stride_(make_Coord(stride_h, stride_w, stride_c)) { }
/// Constructor
// Once convolutions implement 64b stride this ctor can be deleted
CUTLASS_HOST_DEVICE
TensorCWHN(Coord<kStrideRank, LongIndex> const &stride):
stride_(make_Coord(
static_cast<typename Stride::Index>(stride[0]),
static_cast<typename Stride::Index>(stride[1]),
static_cast<typename Stride::Index>(stride[2]))
) { }
/// Helper returns a layout to a tightly packed CWHN tensor.
CUTLASS_HOST_DEVICE
static TensorCWHN packed(TensorCoord const &extent) {
return TensorCWHN(
make_Coord(
extent.n(),
extent.h() * extent.n(),
extent.w() * extent.h() * extent.n()
)
);
}
/// Returns the offset of a coordinate (n, h, w, c) in linear memory.
CUTLASS_HOST_DEVICE
LongIndex operator()(TensorCoord const &coord) const {
return coord.n() +
LongIndex(stride_[0] * coord.h()) +
LongIndex(stride_[1] * coord.w()) +
LongIndex(stride_[2] * coord.c());
}
/// Returns the offset of a pitchlinear coordinate in linear memory.
CUTLASS_HOST_DEVICE
LongIndex operator()(PitchLinearCoord coord) const {
return coord.contiguous() + LongIndex(coord.strided() * stride_[2]);
}
/// Returns the stride of the layout
CUTLASS_HOST_DEVICE
Stride stride() const {
return stride_;
}
/// Returns the stride of the layout
CUTLASS_HOST_DEVICE
Stride & stride() {
return stride_;
}
/// Compute the number of contiguous elements needed to store a tensor with the given size
CUTLASS_HOST_DEVICE
LongIndex capacity(TensorCoord const &extent) const {
// It does not make sense if the extent is larger than the stride,
// and we cannot rely on the capacity calculation in such cases.
// These checks could be moved to debug-only code.
if ((extent.n() > stride_[0])
|| (extent.h() * stride_[0] > stride_[1])
|| (extent.w() * stride_[1] > stride_[2])) {
assert(0);
}
return extent.c() * stride_[2];
}
};
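// Worked example (illustrative numbers): for a packed CWHN tensor with extent (N=2, H=3, W=4, C=8),
// packed() yields stride_ = [2, 6, 24], and coordinate (n=1, h=2, w=3, c=5) maps to offset
// 1 + 2*2 + 6*3 + 24*5 = 143; N varies fastest in memory and C slowest.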
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Mapping function for 4-D NHCW tensors.
class TensorNHCW {
public:
/// Logical rank of tensor
static int const kRank = 4;
/// Rank of stride vector
static int const kStrideRank = 3;
/// Index type used for coordinates
using Index = int32_t;
/// Long index type used for offsets
using LongIndex = int64_t;
/// Logical coordinate (n, h, w, c)
using TensorCoord = Tensor4DCoord;
/// Stride vector
using Stride = Coord<kStrideRank>;
private:
//
// Data members
//
/// Stride data member - [w, cw, hcw]
Stride stride_;
public:
//
// Methods
//
/// Constructor
CUTLASS_HOST_DEVICE
TensorNHCW(Stride const &stride = Stride(0)): stride_(stride) { }
/// Constructor
CUTLASS_HOST_DEVICE
TensorNHCW(
typename Stride::Index stride_c, ///< number of elements between adjacent C coordinates
typename Stride::Index stride_h, ///< number of elements between adjacent H coordinates
typename Stride::Index stride_n ///< number of elements between adjacent N coordinates
):
stride_(make_Coord(stride_c, stride_h, stride_n)) { }
/// Constructor
// Once convolutions implement 64b stride this ctor can be deleted
CUTLASS_HOST_DEVICE
TensorNHCW(Coord<kStrideRank, LongIndex> const &stride):
stride_(make_Coord(
static_cast<typename Stride::Index>(stride[0]),
static_cast<typename Stride::Index>(stride[1]),
static_cast<typename Stride::Index>(stride[2]))
) { }
/// Helper returns a layout to a tightly packed NHCW tensor.
CUTLASS_HOST_DEVICE
static TensorNHCW packed(TensorCoord const &extent) {
return TensorNHCW(
make_Coord(
extent.w(),
extent.c() * extent.w(),
extent.h() * extent.c() * extent.w()
)
);
}
/// Returns the offset of a coordinate (n, h, w, c) in linear memory.
CUTLASS_HOST_DEVICE
LongIndex operator()(TensorCoord const &coord) const {
return coord.w() +
LongIndex(stride_[0] * coord.c()) +
LongIndex(stride_[1] * coord.h()) +
LongIndex(stride_[2] * coord.n());
}
/// Returns the offset of a pitchlinear coordinate in linear memory.
CUTLASS_HOST_DEVICE
LongIndex operator()(PitchLinearCoord coord) const {
return coord.contiguous() + LongIndex(coord.strided() * stride_[2]);
}
/// Returns the stride of the layout
CUTLASS_HOST_DEVICE
Stride stride() const {
return stride_;
}
/// Returns the stride of the layout
CUTLASS_HOST_DEVICE
Stride & stride() {
return stride_;
}
/// Compute the number of contiguous elements needed to store a tensor with the given size
CUTLASS_HOST_DEVICE
LongIndex capacity(TensorCoord const &extent) const {
// It does not make sense if the extent is larger than the stride,
// and we cannot rely on the capacity calculation in such cases.
// These checks could be moved to debug-only code.
if ((extent.w() > stride_[0])
|| (extent.c() * stride_[0] > stride_[1])
|| (extent.h() * stride_[1] > stride_[2])) {
assert(0);
}
return extent.n() * stride_[2];
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Mapping function for 4-D NCWH tensors.
class TensorNCWH {
public:
/// Logical rank of tensor
static int const kRank = 4;
/// Rank of stride vector
static int const kStrideRank = 3;
/// Index type used for coordinates
using Index = int32_t;
/// Long index type used for offsets
using LongIndex = int64_t;
/// Logical coordinate (n, h, w, c)
using TensorCoord = Tensor4DCoord;
/// Stride vector
using Stride = Coord<kStrideRank>;
private:
//
// Data members
//
/// Stride data member - [h, wh, cwh]
Stride stride_;
public:
//
// Methods
//
/// Constructor
CUTLASS_HOST_DEVICE
TensorNCWH(Stride const &stride = Stride(0)): stride_(stride) { }
/// Constructor
CUTLASS_HOST_DEVICE
TensorNCWH(
typename Stride::Index stride_w, ///< number of elements between adjacent W coordinates
typename Stride::Index stride_c, ///< number of elements between adjacent C coordinates
typename Stride::Index stride_n ///< number of elements between adjacent N coordinates
):
stride_(make_Coord(stride_w, stride_c, stride_n)) { }
/// Constructor
// Once convolutions implement 64b stride this ctor can be deleted
CUTLASS_HOST_DEVICE
TensorNCWH(Coord<kStrideRank, LongIndex> const &stride):
stride_(make_Coord(
static_cast<typename Stride::Index>(stride[0]),
static_cast<typename Stride::Index>(stride[1]),
static_cast<typename Stride::Index>(stride[2]))
) { }
/// Helper returns a layout to a tightly packed NCWH tensor.
CUTLASS_HOST_DEVICE
static TensorNCWH packed(TensorCoord const &extent) {
return TensorNCWH(
make_Coord(
extent.h(),
extent.w() * extent.h(),
extent.c() * extent.w() * extent.h()
)
);
}
/// Returns the offset of a coordinate (n, h, w, c) in linear memory.
CUTLASS_HOST_DEVICE
LongIndex operator()(TensorCoord const &coord) const {
return coord.h() +
LongIndex(stride_[0] * coord.w()) +
LongIndex(stride_[1] * coord.c()) +
LongIndex(stride_[2] * coord.n());
}
/// Returns the offset of a pitchlinear coordinate in linear memory.
CUTLASS_HOST_DEVICE
LongIndex operator()(PitchLinearCoord coord) const {
return coord.contiguous() + LongIndex(coord.strided() * stride_[2]);
}
/// Returns the stride of the layout
CUTLASS_HOST_DEVICE
Stride stride() const {
return stride_;
}
/// Returns the stride of the layout
CUTLASS_HOST_DEVICE
Stride & stride() {
return stride_;
}
/// Compute the number of contiguous elements needed to store a tensor with the given size
CUTLASS_HOST_DEVICE
LongIndex capacity(TensorCoord const &extent) const {
// It does not make sense if the extent is larger than the stride,
// and we cannot rely on the capacity calculation in such cases.
// These checks could be moved to debug-only code.
if ((extent.h() > stride_[0])
|| (extent.w() * stride_[0] > stride_[1])
|| (extent.c() * stride_[1] > stride_[2])) {
assert(0);
}
return extent.n() * stride_[2];
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Mapping function for 5-D CWHDN tensors.
class TensorCWHDN {
public:
/// Logical rank of tensor
static int const kRank = 5;
/// Rank of stride vector
static int const kStrideRank = 4;
/// Index type used for coordinates
using Index = int32_t;
/// Long index type used for offsets
using LongIndex = int64_t;
/// Logical coordinate (n, d, h, w, c)
using TensorCoord = Tensor5DCoord;
/// Stride vector
using Stride = Coord<kStrideRank>;
private:
//
// Data members
//
/// Stride data member - [n, dn, hdn, whdn]
Stride stride_;
public:
//
// Methods
//
/// Constructor
CUTLASS_HOST_DEVICE
TensorCWHDN(Stride const &stride = Stride(0)): stride_(stride) { }
/// Constructor
CUTLASS_HOST_DEVICE
TensorCWHDN(
typename Stride::Index n,
typename Stride::Index dn,
typename Stride::Index hdn,
typename Stride::Index whdn):
stride_(make_Coord(n, dn, hdn, whdn)) { }
/// Constructor
// Once convolutions implement 64b stride this ctor can be deleted
CUTLASS_HOST_DEVICE
TensorCWHDN(Coord<kStrideRank, LongIndex> const &stride):
stride_(make_Coord(
static_cast<typename Stride::Index>(stride[0]),
static_cast<typename Stride::Index>(stride[1]),
static_cast<typename Stride::Index>(stride[2]),
static_cast<typename Stride::Index>(stride[3]))
) { }
/// Helper returns a layout to a tightly packed CWHDN tensor.
CUTLASS_HOST_DEVICE
static TensorCWHDN packed(TensorCoord const &extent) {
return TensorCWHDN(
make_Coord(
extent.n(),
extent.d() * extent.n(),
extent.h() * extent.d() * extent.n(),
extent.w() * extent.h() * extent.d() * extent.n()
)
);
}
/// Returns the offset of a coordinate (n, d, h, w, c) in linear memory.
CUTLASS_HOST_DEVICE
LongIndex operator()(TensorCoord const &coord) const {
return coord.n() +
LongIndex(stride_[0] * coord.d()) +
LongIndex(stride_[1] * coord.h()) +
LongIndex(stride_[2] * coord.w()) +
LongIndex(stride_[3] * coord.c());
}
/// Returns the offset of a pitchlinear coordinate in linear memory.
CUTLASS_HOST_DEVICE
LongIndex operator()(PitchLinearCoord coord) const {
return coord.contiguous() + LongIndex(coord.strided() * stride_[3]);
}
/// Returns the stride of the layout
CUTLASS_HOST_DEVICE
Stride stride() const {
return stride_;
}
/// Returns the stride of the layout
CUTLASS_HOST_DEVICE
Stride & stride() {
return stride_;
}
/// Compute the number of contiguous elements needed to store a tensor with the given size
CUTLASS_HOST_DEVICE
LongIndex capacity(TensorCoord const &extent) const {
// It does not make sense if the extent is larger than the stride,
// and we cannot rely on the capacity calculation in such cases.
// These checks could be moved to debug-only code.
if ((extent.n() > stride_[0])
|| (extent.d() * stride_[0] > stride_[1])
|| (extent.h() * stride_[1] > stride_[2])
|| (extent.w() * stride_[2] > stride_[3])) {
assert(0);
}
return extent.c() * stride_[3];
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layout
} // namespace cutlass

View File

@ -0,0 +1,344 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Contains additional metadata about layout permute functions used in the example.
*/
#include "cutlass/tensor_coord.h"
#include "cutlass/layout/permute.h"
/// Additional permutation metadata to facilitate testing/printing
template<typename PermuteLayout>
struct PermuteInfo;
/// Specialization for default case (no permute). Other specializations must follow this template.
template<>
struct PermuteInfo<cutlass::layout::NoPermute> {
/// Whether this is a BMM or GEMM permutation (NoPermute can actually be either)
static bool constexpr kBatched = false;
/// Minimal divisor for row extent
static int constexpr kRowFactor = 1;
/// Minimum divisor for column extent
static int constexpr kColumnFactor = 1;
/// Minimum divisor for batch size dimension
static int constexpr kBatchFactor = 1;
/// Tensor layout used in permutation operation
using Layout = cutlass::layout::PackedVectorLayout;
static std::string name() {
return "NoPermute";
}
/// User-friendly description of the permute operation
static std::string desc() {
return "no permutation";
}
/// Infer original higher-rank tensor shape from GEMM/BMM matrix extents.
/// For direct (output) permutations, must be a simple reshape of extent.
/// For inverse (input) permutations, must return shape *before* permute operation.
/// In case of NoPermute, simply use a linear (rank 1) view of the memory
static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) {
return Layout::TensorCoord(extent.row() * extent.column() * batch_count);
}
/// Compute the permuted higher-rank tensor shape from the original shape.
static Layout::TensorCoord permute(Layout::TensorCoord const &s) {
return s;
}
};
template<int D1>
struct PermuteInfo<cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>> {
static bool constexpr kBatched = true;
static int constexpr kRowFactor = 1;
static int constexpr kColumnFactor = 1;
static int constexpr kBatchFactor = D1;
using Layout = cutlass::layout::TensorNHWC;
static std::string name() {
return "Tensor4DPermuteBMM0213<" + std::to_string(D1) + ">";
}
static std::string desc() {
return "batched GEMM permutation [0, 2, 1, 3]";
}
static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) {
int D0 = batch_count / D1;
int D2 = extent.row();
int D3 = extent.column();
return {D0, D1, D2, D3};
}
static Layout::TensorCoord permute(Layout::TensorCoord const &s) {
return {s[0], s[2], s[1], s[3]};
}
};
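// Worked example (illustrative numbers): with D1 = 3, batch_count = 12 and a 128x64 GEMM extent,
// original_shape() returns {D0 = 4, D1 = 3, D2 = 128, D3 = 64}; applying permute() to that shape
// yields {4, 128, 3, 64}, i.e. the [0, 2, 1, 3] permutation described above.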
template<int D1>
struct PermuteInfo<cutlass::layout::Tensor4DPermuteBMM0213RowMajorInverse<D1>>
: public PermuteInfo<cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>> {
static bool constexpr kBatched = true;
static int constexpr kRowFactor = 1;
static int constexpr kColumnFactor = D1;
static int constexpr kBatchFactor = 1;
using Base = PermuteInfo<cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>>;
using Layout = typename Base::Layout;
static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) {
int D0 = batch_count;
int D2 = extent.row();
int D3 = extent.column() / D1;
return {D0, D1, D2, D3};
}
};
template<int D1>
struct PermuteInfo<cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>> {
static bool constexpr kBatched = true;
static int constexpr kRowFactor = 1;
static int constexpr kColumnFactor = 1;
static int constexpr kBatchFactor = D1;
using Layout = cutlass::layout::TensorNHCW;
static std::string name() {
return "Tensor4DPermuteBMM0321<" + std::to_string(D1) + ">";
}
static std::string desc() {
return "batched GEMM permutation [0, 3, 2, 1]";
}
static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) {
int D0 = batch_count / D1;
int D2 = extent.row();
int D3 = extent.column();
return {D0, D1, D2, D3};
}
static Layout::TensorCoord permute(Layout::TensorCoord const &s) {
return {s[0], s[3], s[2], s[1]};
}
};
template<int D1>
struct PermuteInfo<cutlass::layout::Tensor4DPermuteBMM0321ColumnMajorInverse<D1>>
: public PermuteInfo<cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>> {
static bool constexpr kBatched = true;
static int constexpr kRowFactor = D1;
static int constexpr kColumnFactor = 1;
static int constexpr kBatchFactor = 1;
using Base = PermuteInfo<cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>>;
using Layout = typename Base::Layout;
static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) {
int D0 = batch_count;
int D2 = extent.row() / D1;
int D3 = extent.column();
return {D0, D1, D2, D3};
}
};
template<int D1, int D2>
struct PermuteInfo<cutlass::layout::Tensor4DPermute0213RowMajor<D1, D2>> {
static bool constexpr kBatched = false;
static int constexpr kRowFactor = D1;
static int constexpr kColumnFactor = D2;
static int constexpr kBatchFactor = 1;
using Layout = cutlass::layout::TensorNHWC;
static std::string name() {
return "Tensor4DPermute0213<" + std::to_string(D1) + "," + std::to_string(D2) + ">";
}
static std::string desc() {
return "normal GEMM permutation [0, 2, 1, 3]";
}
static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) {
int D0 = extent.row() / D1;
int D3 = extent.column() / D2;
return {D0, D1, D2, D3};
}
static Layout::TensorCoord permute(Layout::TensorCoord const &s) {
return {s[0], s[2], s[1], s[3]};
}
};
template<int D1, int D2>
struct PermuteInfo<cutlass::layout::Tensor4DPermute0213RowMajorInverse<D1, D2>>
: public PermuteInfo<cutlass::layout::Tensor4DPermute0213RowMajor<D1, D2>> {
static bool constexpr kBatched = false;
static int constexpr kRowFactor = D2;
static int constexpr kColumnFactor = D1;
static int constexpr kBatchFactor = 1;
using Base = PermuteInfo<cutlass::layout::Tensor4DPermute0213RowMajor<D1, D2>>;
using Layout = typename Base::Layout;
static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) {
int D0 = extent.row() / D2;
int D3 = extent.column() / D1;
return {D0, D1, D2, D3};
}
};
template<int D1, int D2>
struct PermuteInfo<cutlass::layout::Tensor4DPermute0213ColumnMajor<D1, D2>>
: public PermuteInfo<cutlass::layout::Tensor4DPermute0213RowMajor<D1, D2>> {
using Layout = cutlass::layout::TensorCWHN;
};
template<int D1, int D2>
struct PermuteInfo<cutlass::layout::Tensor4DPermute0213ColumnMajorInverse<D1, D2>>
: public PermuteInfo<cutlass::layout::Tensor4DPermute0213RowMajorInverse<D1, D2>> {
using Layout = cutlass::layout::TensorCWHN;
};
template<int T1, int T2, int T3>
struct PermuteInfo<cutlass::layout::Tensor5DPermute20314RowMajor<T1, T2, T3>> {
static bool constexpr kBatched = false;
static int constexpr kRowFactor = T1;
static int constexpr kColumnFactor = T2 * T3;
static int constexpr kBatchFactor = 1;
using Layout = cutlass::layout::TensorNDHWC;
static std::string name() {
return "Tensor5DPermute20314<" + std::to_string(T1) + "," + std::to_string(T2) + "," + std::to_string(T3) + ">";
}
static std::string desc() {
return "normal GEMM permutation [2, 0, 3, 1, 4]";
}
static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count)
{
int const T0 = extent.row() / T1;
int const T4 = extent.column() / (T2 * T3);
return {T0, T1, T2, T3, T4};
}
static Layout::TensorCoord permute(Layout::TensorCoord const &s)
{
return {s[2], s[0], s[3], s[1], s[4]};
}
};
template<int T1, int T2, int T3>
struct PermuteInfo<cutlass::layout::Tensor5DPermute20314RowMajorInverse<T1, T2, T3>>
: public PermuteInfo<cutlass::layout::Tensor5DPermute20314RowMajor<T1, T2, T3>> {
static bool constexpr kBatched = false;
static int constexpr kRowFactor = T2;
static int constexpr kColumnFactor = T1 * T3;
static int constexpr kBatchFactor = 1;
using Base = PermuteInfo<cutlass::layout::Tensor5DPermute20314RowMajor<T1, T2, T3>>;
using Layout = typename Base::Layout;
static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) {
int const T0 = extent.row() / T2;
int const T4 = extent.column() / (T1 * T3);
return {T0, T1, T2, T3, T4};
}
};
template<int T1, int T2, int T3>
struct PermuteInfo<cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>> {
static bool constexpr kBatched = false;
static int constexpr kRowFactor = T1;
static int constexpr kColumnFactor = T2 * T3;
static int constexpr kBatchFactor = 1;
using Layout = cutlass::layout::TensorCWHDN;
static std::string name() {
return "Tensor5DPermute02413<" + std::to_string(T1) + "," + std::to_string(T2) + "," + std::to_string(T3) + ">";
}
static std::string desc() {
return "normal GEMM permutation [0, 2, 4, 1, 3]";
}
using Coord = cutlass::Tensor5DCoord;
static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count)
{
int const T0 = extent.row() / T1;
int const T4 = extent.column() / (T2 * T3);
return {T0, T1, T2, T3, T4};
}
static Layout::TensorCoord permute(Layout::TensorCoord const &s)
{
return {s[0], s[2], s[4], s[1], s[3]};
}
};
template<int T1, int T2, int T3>
struct PermuteInfo<cutlass::layout::Tensor5DPermute02413ColumnMajorInverse<T1, T2, T3>>
: public PermuteInfo<cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>> {
static bool constexpr kBatched = false;
static int constexpr kRowFactor = T2;
static int constexpr kColumnFactor = T1 * T3;
static int constexpr kBatchFactor = 1;
using Base = PermuteInfo<cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>>;
using Layout = typename Base::Layout;
static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) {
int const T0 = extent.row() / T2;
int const T4 = extent.column() / (T1 * T3);
return {T0, T1, T2, T3, T4};
}
};

View File

@ -1,22 +1,4 @@
# CUTLASS Python Interface Examples
This directory contains examples of using CUTLASS's Python interface. It consists of two types of examples:
* _Basic examples_: minimal examples that illustrate how to set up GEMMs, convolutions, and grouped GEMM operations
* [_Customizable examples_](customizable): examples that allow one to specify a variety of template parameters for the given kernel
# PyCUTLASS Examples
## Setting up the Python interface
Please follow the instructions [here](/tools/library/scripts/pycutlass/README.md#installation) to set up the Python API.
## Running examples
Each of the basic examples can be run as follows:
```shell
# Run the GEMM example
python gemm.py
# Run the Conv2d example
python conv2d.py
# Run the grouped GEMM example
python gemm_grouped.py
```
To run the customizable examples, refer to the README in the [customizable](customizable) directory.
This directory contains deprecated examples for PyCUTLASS, a precursor to the CUTLASS Python interface.
For examples of using CUTLASS's actively-maintained Pythonic interface, see the [examples/python](/examples/python) directory.

View File

@ -33,15 +33,20 @@
Basic example of using the CUTLASS Python interface to run a 2d convolution
"""
import argparse
import torch
import numpy as np
import sys
print("This example is deprecated. Please see examples/python for examples of using "
"the CUTLASS Python interface.")
sys.exit(0)
import cutlass
import pycutlass
from pycutlass import *
from pycutlass.utils.device import device_cc
import argparse
import numpy as np
import torch
import cutlass_bindings
import cutlass.backend as pycutlass
from cutlass.backend import *
from cutlass.backend.utils.reference_model import Conv2dReferenceModule
from cutlass.backend.utils.device import device_cc
parser = argparse.ArgumentParser(
@ -76,11 +81,11 @@ pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
pycutlass.compiler.nvcc()
# Set up A, B, C and accumulator
A = TensorDescription(cutlass.float16, cutlass.TensorNHWC, alignment)
B = TensorDescription(cutlass.float16, cutlass.TensorNHWC, alignment)
C = TensorDescription(cutlass.float32, cutlass.TensorNHWC, alignment)
element_acc = cutlass.float32
element_epilogue = cutlass.float32
A = TensorDescription(cutlass_bindings.float16, cutlass_bindings.TensorNHWC, alignment)
B = TensorDescription(cutlass_bindings.float16, cutlass_bindings.TensorNHWC, alignment)
C = TensorDescription(cutlass_bindings.float32, cutlass_bindings.TensorNHWC, alignment)
element_acc = cutlass_bindings.float32
element_epilogue = cutlass_bindings.float32
# Select instruction shape based on the Tensor Core instructions supported
# by the device on which we are running
@ -89,12 +94,14 @@ if cc == 70:
elif cc == 75:
instruction_shape = [16, 8, 8]
else:
# Use CUTLASS kernels for CC 80 by default (e.g., for cases in which SM86 is used)
cc = 80
instruction_shape = [16, 8, 16]
math_inst = MathInstruction(
instruction_shape,
A.element, B.element, element_acc,
cutlass.OpClass.TensorOp,
cutlass_bindings.OpClass.TensorOp,
MathOperation.multiply_add
)
@ -108,8 +115,8 @@ tile_description = TileDescription(
epilogue_functor = pycutlass.LinearCombination(C.element, C.alignment, element_acc, element_epilogue)
operation = Conv2dOperation(
conv_kind=cutlass.conv.Operator.fprop,
iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized,
conv_kind=cutlass_bindings.conv.Operator.fprop,
iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized,
arch=cc, tile_description=tile_description,
A=A, B=B, C=C, stride_support=StrideSupport.Strided,
epilogue_functor=epilogue_functor
@ -125,20 +132,20 @@ pycutlass.compiler.add_module(operations)
# Randomly initialize tensors
problem_size = cutlass.conv.Conv2dProblemSize(
cutlass.Tensor4DCoord(args.n, args.h, args.c, args.w),
cutlass.Tensor4DCoord(args.k, args.r, args.s, args.c),
cutlass.Tensor4DCoord(0, 0, 0, 0), # Padding
cutlass.MatrixCoord(1, 1), # Strides
cutlass.MatrixCoord(1, 1), # Dilation
cutlass.conv.Mode.cross_correlation,
problem_size = cutlass_bindings.conv.Conv2dProblemSize(
cutlass_bindings.Tensor4DCoord(args.n, args.h, args.c, args.w),
cutlass_bindings.Tensor4DCoord(args.k, args.r, args.s, args.c),
cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), # Padding
cutlass_bindings.MatrixCoord(1, 1), # Strides
cutlass_bindings.MatrixCoord(1, 1), # Dilation
cutlass_bindings.conv.Mode.cross_correlation,
1, # Split k slices
1 # Groups
)
tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size(operation.conv_kind, problem_size)
tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size(operation.conv_kind, problem_size)
tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size(operation.conv_kind, problem_size)
tensor_A_size = cutlass_bindings.conv.implicit_gemm_tensor_a_size(operation.conv_kind, problem_size)
tensor_B_size = cutlass_bindings.conv.implicit_gemm_tensor_b_size(operation.conv_kind, problem_size)
tensor_C_size = cutlass_bindings.conv.implicit_gemm_tensor_c_size(operation.conv_kind, problem_size)
tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=torch.float16, device="cuda").uniform_(-8.5, 7.5))
tensor_B = torch.ceil(torch.empty(size=(tensor_B_size,), dtype=torch.float16, device="cuda").uniform_(-8.5, 7.5))

View File

@ -165,28 +165,3 @@ Example 7: GELU
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu
```
### Epilogue Visitor Tree
Example 1:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 2:
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -epv ColumnBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 3:
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 4:
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnReduction -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 5:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3
```
Example 6:
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnBroadcast -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 3
```

View File

@ -29,13 +29,18 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
import numpy as np
import pycutlass
from pycutlass import *
from pycutlass.conv2d_operation import *
from pycutlass.utils import reference_model
from pycutlass.utils.device import device_cc
import sys
print("This example is deprecated. Please see examples/python for examples of using "
"the CUTLASS Python interface.")
sys.exit(0)
import numpy as np
import cutlass.backend as pycutlass
from cutlass.backend import *
from cutlass.backend.utils.device import device_cc
from cutlass.backend.conv2d_operation import *
from cutlass.backend.utils.reference_model import Conv2dReferenceModule
import torch.nn.functional as F
import argparse
@ -62,7 +67,7 @@ parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
help='Data type of accumulator')
parser.add_argument('-m', "--math", default="multiply_add",
type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction")
parser.add_argument('-op', "--opcode", default="simt", type=str,
parser.add_argument('-op', "--opcode", default="Simt", type=str,
choices=["Simt", 'TensorOp'],
help='This option describes whether you want to use tensor \
cores (TensorOp) or regular SIMT cores (Simt) on GPU SM')
@ -156,12 +161,12 @@ pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
np.random.seed(0)
element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
element_acc = getattr(cutlass, args.element_acc)
element_a = getattr(cutlass_bindings, args.element_a)
element_b = getattr(cutlass_bindings, args.element_b)
element_c = getattr(cutlass_bindings, args.element_c)
element_acc = getattr(cutlass_bindings, args.element_acc)
math_operation = getattr(MathOperation, args.math)
opclass = getattr(cutlass.OpClass, args.opcode)
opclass = getattr(cutlass_bindings.OpClass, args.opcode)
math_inst = MathInstruction(
args.instruction_shape, element_a, element_b,
@ -173,9 +178,9 @@ tile_description = TileDescription(
math_inst
)
layout_a = getattr(cutlass, args.layout_a)
layout_b = getattr(cutlass, args.layout_b)
layout_c = getattr(cutlass, args.layout_c)
layout_a = getattr(cutlass_bindings, args.layout_a)
layout_b = getattr(cutlass_bindings, args.layout_b)
layout_c = getattr(cutlass_bindings, args.layout_c)
A = TensorDescription(
element_a, layout_a, args.alignment_a
@ -189,7 +194,7 @@ C = TensorDescription(
element_c, layout_c, args.alignment_c
)
element_epilogue = getattr(cutlass, args.element_epilogue)
element_epilogue = getattr(cutlass_bindings, args.element_epilogue)
if (args.activation_function == "identity"
or (args.split_k_mode == "Parallel" and args.split_k_slices > 1)):
#
@ -200,10 +205,10 @@ else:
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
iterator_algorithm = getattr(cutlass.conv.IteratorAlgorithm, args.iterator_algorithm)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
iterator_algorithm = getattr(cutlass_bindings.conv.IteratorAlgorithm, args.iterator_algorithm)
swizzling_functor = getattr(cutlass_bindings, args.swizzling_functor)
stride_support = getattr(StrideSupport, args.stride_support)
conv_kind = getattr(cutlass.conv.Operator, args.conv_kind)
conv_kind = getattr(cutlass_bindings.conv.Operator, args.conv_kind)
operation = Conv2dOperation(
conv_kind=conv_kind, iterator_algorithm=iterator_algorithm,
@ -226,7 +231,7 @@ if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
reduction_operation = ReductionOperation(
shape=cutlass.MatrixCoord(4, 32 * C.alignment),
shape=cutlass_bindings.MatrixCoord(4, 32 * C.alignment),
C=C, element_accumulator=element_acc,
element_compute=element_epilogue,
epilogue_functor=epilogue_functor_reduction,
@ -236,34 +241,34 @@ if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
pycutlass.compiler.add_module(operations)
problem_size = cutlass.conv.Conv2dProblemSize(
cutlass.Tensor4DCoord(args.nhwc[0], args.nhwc[1], args.nhwc[2], args.nhwc[3]),
cutlass.Tensor4DCoord(args.krsc[0], args.krsc[1], args.krsc[2], args.krsc[3]),
cutlass.Tensor4DCoord(args.pad[0], args.pad[1], args.pad[2], args.pad[3]),
cutlass.MatrixCoord(args.stride[0], args.stride[1]),
cutlass.MatrixCoord(args.dilation[0], args.dilation[1]),
cutlass.conv.Mode.cross_correlation,
problem_size = cutlass_bindings.conv.Conv2dProblemSize(
cutlass_bindings.Tensor4DCoord(args.nhwc[0], args.nhwc[1], args.nhwc[2], args.nhwc[3]),
cutlass_bindings.Tensor4DCoord(args.krsc[0], args.krsc[1], args.krsc[2], args.krsc[3]),
cutlass_bindings.Tensor4DCoord(args.pad[0], args.pad[1], args.pad[2], args.pad[3]),
cutlass_bindings.MatrixCoord(args.stride[0], args.stride[1]),
cutlass_bindings.MatrixCoord(args.dilation[0], args.dilation[1]),
cutlass_bindings.conv.Mode.cross_correlation,
args.split_k_slices, 1
)
# User-provide inputs
tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size(
tensor_A_size = cutlass_bindings.conv.implicit_gemm_tensor_a_size(
conv_kind, problem_size
)
tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size(
tensor_B_size = cutlass_bindings.conv.implicit_gemm_tensor_b_size(
conv_kind, problem_size
)
if args.bias:
tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_extent(
tensor_C_size = cutlass_bindings.conv.implicit_gemm_tensor_c_extent(
conv_kind, problem_size
).at(3)
else:
tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size(
tensor_C_size = cutlass_bindings.conv.implicit_gemm_tensor_c_size(
conv_kind, problem_size
)
tensor_D_size = cutlass.conv.implicit_gemm_tensor_c_size(
tensor_D_size = cutlass_bindings.conv.implicit_gemm_tensor_c_size(
conv_kind, problem_size
)
@ -288,12 +293,12 @@ arguments = Conv2dArguments(
operation=operation, problem_size=problem_size, A=tensor_A,
B=tensor_B, C=tensor_C, D=tensor_D,
output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)),
split_k_mode=getattr(cutlass.conv.SplitKMode, args.split_k_mode),
split_k_mode=getattr(cutlass_bindings.conv.SplitKMode, args.split_k_mode),
split_k_slices=problem_size.split_k_slices
)
if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
implicit_gemm_size = cutlass.conv.implicit_gemm_problem_size(conv_kind, arguments.problem_size)
implicit_gemm_size = cutlass_bindings.conv.implicit_gemm_problem_size(conv_kind, arguments.problem_size)
reduction_arguments = ReductionArguments(
reduction_operation,
problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()],

View File

@ -29,13 +29,18 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
import numpy as np
import pycutlass
from pycutlass import *
from pycutlass.utils.device import device_cc
import cutlass
from bfloat16 import bfloat16
import sys
print("This example is deprecated. Please see examples/python for examples of using "
"the CUTLASS Python interface.")
sys.exit(0)
import numpy as np
import cutlass.backend as pycutlass
from cutlass.backend import *
from cutlass.backend.utils.device import device_cc
import cutlass_bindings
from bfloat16 import bfloat16
import argparse
@ -62,7 +67,7 @@ parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
help='Data type of accumulator')
parser.add_argument('-m', "--math", default="multiply_add",
type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction")
parser.add_argument('-op', "--opcode", default="simt", type=str,
parser.add_argument('-op', "--opcode", default="Simt", type=str,
choices=["Simt", 'TensorOp'],
help="This option describes whether you want to use tensor \
cores (TensorOp) or regular SIMT cores (Simt) on GPU SM")
@ -100,8 +105,6 @@ parser.add_argument("-te", "--element_epilogue", default="float32", type=str,
parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'],
help="This option describes the epilogue part of the kernel")
parser.add_argument("-epv", "--epilogue_visitor", default=None,
type=str, choices=['RowReduction', 'ColumnReduction', 'RowBroadcast', 'ColumnBroadcast'], help="epilogue visitor for more complex epilogues")
# swizzling
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[
"IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle", "BatchedIdentitySwizzle"],
@ -147,12 +150,12 @@ pycutlass.compiler.nvcc()
np.random.seed(0)
element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
element_acc = getattr(cutlass, args.element_acc)
element_a = getattr(cutlass_bindings, args.element_a)
element_b = getattr(cutlass_bindings, args.element_b)
element_c = getattr(cutlass_bindings, args.element_c)
element_acc = getattr(cutlass_bindings, args.element_acc)
math_operation = getattr(MathOperation, args.math)
opclass = getattr(cutlass.OpClass, args.opcode)
opclass = getattr(cutlass_bindings.OpClass, args.opcode)
math_inst = MathInstruction(
args.instruction_shape, element_a, element_b,
@ -164,9 +167,9 @@ tile_description = TileDescription(
math_inst
)
layout_a = getattr(cutlass, args.layout_a)
layout_b = getattr(cutlass, args.layout_b)
layout_c = getattr(cutlass, args.layout_c)
layout_a = getattr(cutlass_bindings, args.layout_a)
layout_b = getattr(cutlass_bindings, args.layout_b)
layout_c = getattr(cutlass_bindings, args.layout_c)
A = TensorDescription(
element_a, layout_a, args.alignment_a
@ -180,7 +183,7 @@ C = TensorDescription(
element_c, layout_c, args.alignment_c
)
element_epilogue = getattr(cutlass, args.element_epilogue)
element_epilogue = getattr(cutlass_bindings, args.element_epilogue)
if (args.activation_function == "identity"
or (args.gemm_mode == "GemmSplitKParallel" and args.split_k_slices > 1)):
#
@ -191,73 +194,12 @@ else:
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
visitor = args.epilogue_visitor is not None
if args.epilogue_visitor == "ColumnReduction":
class ColumnReduction_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
alpha: 'scalar', beta: 'scalar'):
#
D = alpha * accum + beta * c
reduction = reduction_op(D, "column", "Add", args.threadblock_shape[0])
return D, reduction
epilogue_functor = ColumnReduction_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
elif args.epilogue_visitor == "RowReduction":
class RowReduction_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
alpha: 'scalar', beta: 'scalar'):
#
D = alpha * accum + tanh.numpy(beta * c)
reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1])
return D, reduction
epilogue_functor = RowReduction_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
elif args.epilogue_visitor == "RowBroadcast":
class RowBroadcast_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
vector: 'row', alpha: 'scalar', beta: 'scalar'):
#
T = accum + vector
scale_T = alpha * T
Z = relu.numpy(scale_T + beta * c)
return Z, T
epilogue_functor = RowBroadcast_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
elif args.epilogue_visitor == "ColumnBroadcast":
class ColumnBroadcast_(EpilogueVisitTree):
def __call__(
self, accum: 'tensor', c: 'tensor',
vector: 'column', alpha: 'scalar', beta: 'scalar'):
#
T = accum + vector
scale_T = leaky_relu.numpy(alpha * T, 0.2)
Z = scale_T + beta * c
return Z, T
epilogue_functor = ColumnBroadcast_(
epilogue_functor, tile_description, math_inst.element_accumulator,
C.alignment, element_epilogue, C.element)
epilogue_functor.initialize()
else:
epilogue_functor = epilogue_functor
swizzling_functor = getattr(cutlass_bindings, args.swizzling_functor)
operation = GemmOperationUniversal(
arch=args.compute_capability, tile_description=tile_description,
A=A, B=B, C=C,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor,
visitor=visitor
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
if args.print_cuda:
@ -275,7 +217,7 @@ if args.gemm_mode == "GemmSplitKParallel":
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
reduction_operation = ReductionOperation(
shape=cutlass.MatrixCoord(4, 32 * C.alignment),
shape=cutlass_bindings.MatrixCoord(4, 32 * C.alignment),
C=C, element_accumulator=element_acc,
element_compute=element_epilogue,
epilogue_functor=epilogue_functor_reduction,
@ -287,7 +229,7 @@ pycutlass.compiler.add_module(operations)
# User-provide inputs
problem_size = cutlass.gemm.GemmCoord(
problem_size = cutlass_bindings.gemm.GemmCoord(
args.problem_size[0], args.problem_size[1], args.problem_size[2])
tensor_a_size = args.batch * problem_size.m() * problem_size.k()
@ -347,44 +289,13 @@ tensor_D = np.zeros(
shape=(args.batch * problem_size.m() * problem_size.n(),)
).astype(getattr(np, args.element_c))
if args.epilogue_visitor == "RowReduction":
cta_n = args.threadblock_shape[1]
num_cta_n = (problem_size.n() + cta_n - 1) // cta_n
reduction = np.zeros(shape=(args.batch * problem_size.m() * num_cta_n,), dtype=getattr(np, args.element_c))
output_op = operation.epilogue_type(
D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()]
)
elif args.epilogue_visitor == "ColumnReduction":
cta_m = args.threadblock_shape[0]
num_cta_m = (problem_size.m() + cta_m - 1) // cta_m
reduction = np.zeros(shape=(args.batch * problem_size.n() * num_cta_m,), dtype=getattr(np, args.element_c))
output_op = operation.epilogue_type(
D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()]
)
elif args.epilogue_visitor == "RowBroadcast":
vector = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(args.batch, 1, problem_size.n()))
).astype(getattr(np, args.element_c))
tensor_t = np.empty_like(tensor_D)
output_op = operation.epilogue_type(
c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()]
)
elif args.epilogue_visitor == "ColumnBroadcast":
vector = np.ceil(
np.random.uniform(low=-8.5, high=7.5, size=(args.batch, problem_size.m(), 1))
).astype(getattr(np, args.element_c))
tensor_t = np.empty_like(tensor_D)
output_op = operation.epilogue_type(
c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()]
)
else:
output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args))
output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args))
arguments = GemmArguments(
operation=operation, problem_size=problem_size,
A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
output_op=output_op,
gemm_mode=getattr(cutlass.gemm.Mode, args.gemm_mode),
gemm_mode=getattr(cutlass_bindings.gemm.Mode, args.gemm_mode),
split_k_slices=args.split_k_slices, batch=args.batch
)
@ -411,38 +322,8 @@ reference = ReferenceModule(A, B, C)
tensor_D_ref = reference.run(
tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta, args.bias, args.batch)
if args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]:
tensor_D_ref = (tensor_D_ref.reshape((args.batch, problem_size.m(), problem_size.n())) + vector).flatten()
tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args))
if args.epilogue_visitor in ["RowReduction", "ColumnReduction"]:
output_op.sync()
accum_ref = reference.run(
tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch)
tensor_D_ref, reduction_ref = epilogue_functor(
accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())),
tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())),
args.alpha, args.beta
)
tensor_D_ref = tensor_D_ref.flatten()
reduction_ref = reduction_ref.flatten()
assert np.allclose(reduction_ref, reduction, atol=1e-2)
elif args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]:
output_op.sync()
accum_ref = reference.run(
tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch)
tensor_D_ref, tensor_T_ref = epilogue_functor(
accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())),
tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())),
vector, args.alpha, args.beta)
tensor_D_ref = tensor_D_ref.flatten()
tensor_T_ref = tensor_T_ref.flatten()
assert np.array_equal(tensor_t, tensor_T_ref)
try:
assert np.array_equal(tensor_D, tensor_D_ref)
except:

View File

@ -29,12 +29,17 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
import numpy as np
import pycutlass
from pycutlass import *
from pycutlass.utils.device import device_cc
import csv
import sys
print("This example is deprecated. Please see examples/python for examples of using "
"the CUTLASS Python interface.")
sys.exit(0)
import numpy as np
import cutlass.backend as pycutlass
from cutlass.backend import *
from cutlass.backend.utils.device import device_cc
import csv
import argparse
@ -61,7 +66,7 @@ parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
help='Data type of accumulator')
parser.add_argument('-m', "--math", default="multiply_add",
type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction")
parser.add_argument('-op', "--opcode", default="simt", type=str,
parser.add_argument('-op', "--opcode", default="Simt", type=str,
choices=["Simt", 'TensorOp'], help='This option describes whether you want to use tensor \
cores (TensorOp) or regular SIMT cores (Simt) on GPU SM')
# tile description
@ -111,7 +116,7 @@ parser.add_argument("-pm", "--precompute_mode",
default="Device", type=str, choices=["Host", "Device"],
help="Grouped Gemm Scheduing on device only (Device) or using host precompute (Host)")
# arguments
parser.add_argument("-p", "--problem_size_dir", type=str,
parser.add_argument("-p", "--problem_size_dir", type=str, default="grouped_gemm_problem_size.csv",
help="path to the csv file contains the problem sizes")
parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha")
parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta")
@ -139,12 +144,12 @@ pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
np.random.seed(0)
element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
element_acc = getattr(cutlass, args.element_acc)
element_a = getattr(cutlass_bindings, args.element_a)
element_b = getattr(cutlass_bindings, args.element_b)
element_c = getattr(cutlass_bindings, args.element_c)
element_acc = getattr(cutlass_bindings, args.element_acc)
math_operation = getattr(MathOperation, args.math)
opclass = getattr(cutlass.OpClass, args.opcode)
opclass = getattr(cutlass_bindings.OpClass, args.opcode)
math_inst = MathInstruction(
args.instruction_shape, element_a, element_b,
@ -156,9 +161,9 @@ tile_description = TileDescription(
math_inst
)
layout_a = getattr(cutlass, args.layout_a)
layout_b = getattr(cutlass, args.layout_b)
layout_c = getattr(cutlass, args.layout_c)
layout_a = getattr(cutlass_bindings, args.layout_a)
layout_b = getattr(cutlass_bindings, args.layout_b)
layout_c = getattr(cutlass_bindings, args.layout_c)
A = TensorDescription(
element_a, layout_a, args.alignment_a
@ -172,7 +177,7 @@ C = TensorDescription(
element_c, layout_c, args.alignment_c
)
element_epilogue = getattr(cutlass, args.element_epilogue)
element_epilogue = getattr(cutlass_bindings, args.element_epilogue)
if args.activation_function == "identity":
epilogue_functor = getattr(pycutlass, args.epilogue_functor)(
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
@ -180,7 +185,7 @@ else:
epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")(
getattr(pycutlass, args.activation_function)(element_epilogue),
C.element, C.alignment, math_inst.element_accumulator, element_epilogue)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
swizzling_functor = getattr(cutlass_bindings, args.swizzling_functor)
precompute_mode = getattr(SchedulerMode, args.precompute_mode)
operation = GemmOperationGrouped(
@ -203,7 +208,7 @@ with open(args.problem_size_dir) as csv_file:
reader = csv.reader(csv_file)
for row in reader:
problem_sizes.append(
cutlass.gemm.GemmCoord(int(row[0]), int(row[1]), int(row[2]))
cutlass_bindings.gemm.GemmCoord(int(row[0]), int(row[1]), int(row[2]))
)
problem_count = len(problem_sizes)

View File

@ -33,14 +33,18 @@
Basic example of using the CUTLASS Python interface to run a GEMM
"""
import sys
print("This example is deprecated. Please see examples/python for examples of using "
"the CUTLASS Python interface.")
sys.exit(0)
import argparse
import numpy as np
import sys
import cutlass
import pycutlass
from pycutlass import *
from pycutlass.utils.device import device_cc
import cutlass_bindings
import cutlass.backend as pycutlass
from cutlass.backend import *
from cutlass.backend.utils.device import device_cc
parser = argparse.ArgumentParser(description="Launch a GEMM kernel from Python: 'D = alpha * A * B + beta * C'")
@ -72,11 +76,11 @@ pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
pycutlass.compiler.nvcc()
# Set up A, B, C and accumulator
A = TensorDescription(cutlass.float16, cutlass.ColumnMajor, alignment)
B = TensorDescription(cutlass.float16, cutlass.RowMajor, alignment)
C = TensorDescription(cutlass.float32, cutlass.ColumnMajor, alignment)
element_acc = cutlass.float32
element_epilogue = cutlass.float32
A = TensorDescription(cutlass_bindings.float16, cutlass_bindings.ColumnMajor, alignment)
B = TensorDescription(cutlass_bindings.float16, cutlass_bindings.RowMajor, alignment)
C = TensorDescription(cutlass_bindings.float32, cutlass_bindings.ColumnMajor, alignment)
element_acc = cutlass_bindings.float32
element_epilogue = cutlass_bindings.float32
# Select instruction shape based on the Tensor Core instructions supported
# by the device on which we are running
@ -85,12 +89,14 @@ if cc == 70:
elif cc == 75:
instruction_shape = [16, 8, 8]
else:
# Use CUTLASS kernels for CC 80 by default (e.g., for cases in which SM86 is used)
cc = 80
instruction_shape = [16, 8, 16]
math_inst = MathInstruction(
instruction_shape,
A.element, B.element, element_acc,
cutlass.OpClass.TensorOp,
cutlass_bindings.OpClass.TensorOp,
MathOperation.multiply_add
)
@ -122,7 +128,7 @@ tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(args.k * args.n,)
tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(args.m * args.n,))).astype(np.float32)
tensor_D = np.zeros(shape=(args.m * args.n,)).astype(np.float32)
problem_size = cutlass.gemm.GemmCoord(args.m, args.n, args.k)
problem_size = cutlass_bindings.gemm.GemmCoord(args.m, args.n, args.k)
alpha = 1.
beta = 0.

View File

@ -33,14 +33,18 @@
Basic example of using the CUTLASS Python interface to run a grouped GEMM
"""
import sys
print("This example is deprecated. Please see examples/python for examples of using "
"the CUTLASS Python interface.")
sys.exit(0)
import argparse
import numpy as np
import sys
import cutlass
import pycutlass
from pycutlass import *
from pycutlass.utils.device import device_cc
import cutlass_bindings
import cutlass.backend as pycutlass
from cutlass.backend import *
from cutlass.backend.utils.device import device_cc
parser = argparse.ArgumentParser(description="Launch a grouped GEMM kernel from Python")
@ -65,11 +69,11 @@ pycutlass.compiler.nvcc()
# Set up A, B, C and accumulator
alignment = 1
A = TensorDescription(cutlass.float16, cutlass.ColumnMajor, alignment)
B = TensorDescription(cutlass.float16, cutlass.RowMajor, alignment)
C = TensorDescription(cutlass.float32, cutlass.ColumnMajor, alignment)
element_acc = cutlass.float32
element_epilogue = cutlass.float32
A = TensorDescription(cutlass_bindings.float16, cutlass_bindings.ColumnMajor, alignment)
B = TensorDescription(cutlass_bindings.float16, cutlass_bindings.RowMajor, alignment)
C = TensorDescription(cutlass_bindings.float32, cutlass_bindings.ColumnMajor, alignment)
element_acc = cutlass_bindings.float32
element_epilogue = cutlass_bindings.float32
# Select instruction shape based on the Tensor Core instructions supported
# by the device on which we are running
@ -78,12 +82,14 @@ if cc == 70:
elif cc == 75:
instruction_shape = [16, 8, 8]
else:
# Use CUTLASS kernels for CC 80 by default (e.g., for cases in which SM86 is used)
cc = 80
instruction_shape = [16, 8, 16]
math_inst = MathInstruction(
instruction_shape,
A.element, B.element, element_acc,
cutlass.OpClass.TensorOp,
cutlass_bindings.OpClass.TensorOp,
MathOperation.multiply_add
)
@ -112,8 +118,8 @@ pycutlass.compiler.add_module(operations)
# Initialize tensors for each problem in the group
problem_sizes = [
cutlass.gemm.GemmCoord(128, 128, 64),
cutlass.gemm.GemmCoord(512, 256, 128)
cutlass_bindings.gemm.GemmCoord(128, 128, 64),
cutlass_bindings.gemm.GemmCoord(512, 256, 128)
]
problem_count = len(problem_sizes)

View File

@ -37,8 +37,20 @@ cutlass_example_add_executable(
fused_multihead_attention_variable_seqlen.cu
)
cutlass_example_add_executable(
41_fused_multi_head_attention_backward
fused_multi_head_attention_backward.cu
DISABLE_TESTS ON
)
add_custom_target(41_fused_multi_head_attention
DEPENDS 41_fused_multi_head_attention_fixed_seqlen
41_fused_multi_head_attention_variable_seqlen
41_fused_multi_head_attention_backward
)
add_test(
NAME ctest_examples_41_fmha_backward_python
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/fmha_backward_test.py $<TARGET_FILE:41_fused_multi_head_attention_backward>
)

View File

@ -30,10 +30,10 @@
**************************************************************************************************/
/*! \file
\brief
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
specializations here choose 'device::GemmTransposed' to implement this functionality.
@ -50,6 +50,7 @@
#include "fmha_grouped.h"
#include "gemm_kernel_utils.h"
#include "gemm/custom_mma.h"
#include "gemm/find_default_mma.h"
#include "gemm/mma_from_smem.h"
@ -70,7 +71,7 @@ template <
bool isAligned_,
int kQueriesPerBlock,
int kKeysPerBlock,
bool kSingleValueIteration,
int kMaxK = (int)cutlass::platform::numeric_limits<uint32_t>::max(),
GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly
>
struct DefaultFMHAGrouped {
@ -85,6 +86,8 @@ struct DefaultFMHAGrouped {
using ArchTag = ArchTag_;
static bool const kIsAligned = isAligned_;
static bool const kSingleValueIteration = kMaxK <= kKeysPerBlock;
static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value == 16;
static int const kWarpSize = 32;
static int const kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (kWarpSize * kWarpSize);
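// e.g., assuming each warp covers a 32x32 tile of the Q-by-K block, a 64x64
// threadblock tile yields 64 * 64 / (32 * 32) = 4 warps per block.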
@ -145,14 +148,20 @@ struct DefaultFMHAGrouped {
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
ArchTag::kMinComputeCapability >= 80 && kIsHalf
? 4
: DefaultConfig::kStages,
Operator
>::DefaultMma;
using MmaCore = typename DefaultMma::MmaCore;
using IteratorA = typename DefaultMma::IteratorA;
using IteratorB = typename DefaultMma::IteratorB;
using Mma = typename DefaultMma::ThreadblockMma;
using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma;
using Mma = typename cutlass::platform::conditional<
kSingleValueIteration,
typename MakeCustomMma<DefaultThreadblockMma, kMaxK>::Mma,
DefaultThreadblockMma>::type;
using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
typename Mma::Operator::IteratorC,
ElementAccumulator,
@ -232,14 +241,24 @@ struct DefaultFMHAGrouped {
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
kStages,
ArchTag::kMinComputeCapability >= 80 && kIsHalf
? 4
: DefaultConfig::kStages,
kSplitKSerial,
Operator>;
using WarpIteratorA = typename cutlass::gemm::threadblock::
DefaultWarpIteratorAFromSharedMemory<
typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape
typename DefaultGemm::Mma::Policy::Operator::InstructionShape,
typename DefaultGemm::Mma::Policy::Operator::IteratorA,
typename DefaultGemm::Mma::Policy>::WarpIterator;
using DefaultMmaFromSmem =
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
typename DefaultGemm::Mma,
typename MM0::AccumulatorSharedStorage,
MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK
WarpIteratorA,
false>; // kScaleOperandA
using Mma = typename DefaultMmaFromSmem::Mma;
@ -256,10 +275,6 @@ struct DefaultFMHAGrouped {
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
output_accum_t>;
struct SharedStorageMM1 {
typename Mma::SharedStorage mm;
};
};
/// Define the kernel in terms of the default kernel

View File

@ -0,0 +1,200 @@
import argparse
import torch
import sys
import os
from piped_subprocess import PipedSubprocess, TORCH_DTYPE_NAME
import math
parser = argparse.ArgumentParser()
parser.add_argument("example_exe", type=str, help="Path to the 41_fused_multi_head_attention_backward executable")
args = parser.parse_args()
torch.manual_seed(0)
dtype = torch.float16
B, Mq, Mkv, H, K, Kv = 2, 1024, 1024, 5, 128, 128
causal = True
repeat_count = 100
ATOL = {
torch.float: 5e-4,
torch.half: 9.5e-2,
torch.bfloat16: 7e-1,
}[dtype]
RTOL = {
torch.float: 1e-4,
torch.half: 2e-2,
torch.bfloat16: 1e-1,
}[dtype]
assert not (causal and Mq < Mkv), "causal only supports seqlenK <= seqlenQ"
fmha_bw_binary = args.example_exe
if not os.path.isfile(fmha_bw_binary):
print(f"""No such file: `{fmha_bw_binary}`\nDid you forget to run "make 41_fused_multi_head_attention"?""")
sys.exit(1)
def create_lower_triangular_mask():
return torch.triu(torch.full( # type: ignore
[1, Mq, Mkv],
dtype=dtype,
fill_value=float("-inf"),
), diagonal=1)
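# Despite the function name, the tensor above is upper-triangular: -inf strictly
# above the diagonal and 0 on/below it. Adding it to the attention logits keeps
# only keys j <= i for query i, i.e. the usual causal (lower-triangular) pattern.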
def ref_mha_bmk(q, k, v, mask):
# Multi-head attention with inputs/outputs in BMK format
q = q.float()
k = k.float()
v = v.float()
q = q * (1 / q.shape[-1] ** 0.5)
attn = q @ k.transpose(-2, -1)
if mask is not None:
attn += mask
attn_max = attn.max(-1, True).values
attn_norm = (attn - attn_max).exp().sum(-1, True)
attn = attn.softmax(-1)
lse = attn_max + attn_norm.log()
lse = lse.squeeze(2)
return attn @ v, lse
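# `lse` returned above is the numerically stable log-sum-exp of the logits per
# query row: lse = max + log(sum(exp(attn - max))), which should equal
# torch.logsumexp(attn, dim=-1) up to floating-point error. The backward
# reference below uses it to rebuild the softmax as exp(attn - lse) without
# storing the full attention matrix.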
def bmhk2bmk(t):
return t.permute((0, 2, 1, 3)).reshape(
[t.shape[0] * t.shape[2], t.shape[1], t.shape[3]]
)
def ref_mha_bmhk(q, k, v, mask):
# Multi-head attention with inputs/outputs in BMHK format
assert q.ndim == 4
out, lse = ref_mha_bmk(bmhk2bmk(q), bmhk2bmk(k), bmhk2bmk(v), mask=mask)
out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]])
return out.permute((0, 2, 1, 3)), lse.reshape([q.shape[0], q.shape[2], q.shape[1]])
def ref_mha_bw_bmhk(q, k, v, mask, lse, out, grad_out, delta):
lse = lse[:, :, :q.shape[1]] #BMH, unpad Q dimension
delta = delta.reshape([-1, delta.shape[-1], 1])
# bmhk -> bmk
q, k, v, out, grad_out = [bmhk2bmk(x).float() for x in (q, k, v, out, grad_out)]
attn_T = k @ q.transpose(-2, -1)
if mask is not None:
attn_T += mask.transpose(-2, -1)
attn_T = attn_T * (1 / q.shape[-1] ** 0.5)
attn_T = attn_T - lse.reshape([-1, 1, lse.shape[-1]])
attn_T = attn_T.exp()
grad_v = attn_T @ grad_out
dov = grad_out @ v.transpose(-2, -1)
tmp = (dov - delta) * attn_T.transpose(-2, -1)
tmp = tmp / (q.shape[-1] ** 0.5)
grad_q = tmp @ k
grad_k = tmp.transpose(-2, -1) @ q
return [x.reshape([B, H, x.shape[1], x.shape[-1]]).permute([0, 2, 1, 3]) for x in [grad_q, grad_k, grad_v]]
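# Informal recap of the math implemented above (BMK layout, scale = 1/sqrt(K)):
#   P  = softmax(scale * Q @ K^T + mask) = exp(scale * Q @ K^T + mask - lse)
#   dV = P^T @ dO
#   dS = P * (dO @ V^T - delta),  with delta = rowsum(dO * O)
#   dQ = scale * dS @ K,          dK = scale * dS^T @ Q
# The code manipulates the transposed score matrix (attn_T) for convenience only.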
print("initializing tensors...")
query = torch.randn([B, Mq, H, K], dtype=dtype)
key = 3 * torch.randn([B, Mkv, H, K], dtype=dtype)
value = 3 * torch.randn([B, Mkv, H, Kv], dtype=dtype)
mask = create_lower_triangular_mask() if causal else None
# let PyTorch compute gradients
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
print("computing fw...")
out, lse = ref_mha_bmhk(query, key, value, mask=mask)
out = out.to(dtype).contiguous()
grad_out = 3 * torch.randn([B, Mq, H, Kv], dtype=dtype)
print("computing bw with autograd...")
out.backward(grad_out)
scale = (1 / query.shape[-1] ** 0.5)
# Additional data needed by the kernel
delta = (grad_out.float() * out.float()).sum(-1).transpose(-2, -1).contiguous()
pad_amount = (32 - (lse.shape[2] % 32)) % 32
lse = torch.nn.functional.pad(lse, [0, pad_amount], value=math.inf)
print("computing bw with reference implem...")
gQr, gKr, gVr = ref_mha_bw_bmhk(query, key, value, mask, lse, out, grad_out, delta)
with PipedSubprocess(fmha_bw_binary) as bw_kernel:
# Send kernel arguments
bw_kernel.write(
TORCH_DTYPE_NAME[query.dtype],
"scale", scale,
"head_dim", K,
"head_dim_value", Kv,
"num_queries", Mq,
"num_keys", Mkv,
"num_heads", H,
"custom_mask_type", (1 if causal else 0),
"num_batches", B,
"repeat_count", repeat_count,
"num_splits_key", (Mkv // 128),
)
bw_kernel.writeTensor(query, "query", ["q_strideB", "q_strideM", "q_strideH"])
bw_kernel.writeTensor(key, "key", ["k_strideB", "k_strideM", "k_strideH"])
bw_kernel.writeTensor(value, "value", ["v_strideB", "v_strideM", "v_strideH"])
bw_kernel.writeTensor(lse, "logsumexp", ["lse_strideB", "lse_strideH"])
bw_kernel.writeTensor(out, "output", ["o_strideB", "o_strideM", "o_strideH"])
bw_kernel.writeTensor(grad_out, "grad_output", ["gO_strideB", "gO_strideM", "gO_strideH"])
bw_kernel.writeTensor(delta, "delta", ["delta_strideB", "delta_strideH"])
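# All inputs have now been streamed to the C++ binary over its stdin: named
# scalars first, then each tensor together with its stride names. The binary is
# expected to answer "OK" on stdout, followed by the gradient tensors and a
# "runtime_ms" value, which are read back below via PipedSubprocess.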
if bw_kernel.read() != "OK":
print("Got unexpected output")
print(bw_kernel.subp.communicate()[0])
sys.exit(0)
# Read kernel output
gQ = bw_kernel.readTensor("grad_query", ["gQ_strideB", "gQ_strideM", "gQ_strideH"], query.shape).float()
gK = bw_kernel.readTensor("grad_key", ["gK_strideB", "gK_strideM", "gK_strideH"], key.shape).float()
gV = bw_kernel.readTensor("grad_value", ["gV_strideB", "gV_strideM", "gV_strideH"], value.shape).float()
runtime_ms = float(bw_kernel.readNamed("runtime_ms"))
float_ops = B * H * sum([
# att = Q @ K.transpose
Mq * Mkv * K * 2,
# att @ dO
Mkv * Mq * Kv * 2,
# dov = dO @ V
Mq * Kv * Mkv * 2,
# dov @ K
Mq * K * Mkv * 2,
# dov @ Q
Mq * K * Mkv * 2,
])
if causal:
float_ops //= 2
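# With a causal mask roughly half of the Mq x Mkv score matrix is masked out,
# so the FLOP estimate is halved; this only affects the reported TFLOPs below.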
print(f"""
Fused multi-head attention - backward
batch_size={B}
num_queries={Mq}
num_keys={Mkv}
num_heads={H}
head_dim={K}
head_dim_value={Kv}
Correctness:
grad_query: {"PASS" if torch.allclose(gQ, gQr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gQ - gQr).abs().max()})
grad_key: {"PASS" if torch.allclose(gK, gKr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gK - gKr).abs().max()})
grad_value: {"PASS" if torch.allclose(gV, gVr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gV - gVr).abs().max()})
(atol={ATOL} / rtol={RTOL})
Runtime: {runtime_ms}ms ({(float_ops / (1024 ** 4)) / (runtime_ms / 1000):.4f} TFlops)
""")
assert torch.allclose(query.grad.float(), gQr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!"
assert torch.allclose(key.grad.float(), gKr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!"
assert torch.allclose(value.grad.float(), gVr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!"

View File

@ -147,6 +147,9 @@ public:
static int const kThreadsPerWarp = 32;
static int const kThreadCount = kThreadsPerWarp * WarpCount::kCount;
static constexpr int kNumWarpsPerBlock =
kQueriesPerBlock * kKeysPerBlock / (kThreadsPerWarp * kThreadsPerWarp);
using ProblemVisitor = FMHAGroupedProblemVisitor<
ThreadblockShape,
kGroupScheduleMode,
@ -369,13 +372,16 @@ public:
cutlass::Array<ElementAccumulator, kQueriesPerBlock> m_prime;
cutlass::Array<ElementAccumulator, kQueriesPerBlock> s_prime;
cutlass::Array<ElementAccumulator, kQueriesPerBlock> mi;
cutlass::Array<ElementAccumulator, kQueriesPerBlock> out_rescale;
cutlass::Array<ElementAccumulator, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>
addition_storage;
};
struct SharedStorageEpilogueAtEnd : ScalingCoefs {
struct SharedStorageAfterMM0 {
// Everything here might be overwritten during MM0
typename MM0::AccumulatorSharedStorage si;
typename MM1::SharedStorageMM1 mm1;
typename MM1::Mma::SharedStorage mm1;
};
union {
@ -397,7 +403,7 @@ public:
struct SharedStorageAfterMM0 {
// Everything here might be overwritten during MM0
typename MM0::AccumulatorSharedStorage si;
typename MM1::SharedStorageMM1 mm1;
typename MM1::Mma::SharedStorage mm1;
typename MM1::DefaultEpilogue::SharedStorage epilogue;
};
@ -490,6 +496,7 @@ public:
auto& s_prime = shared_storage.s_prime;
[[maybe_unused]] auto& si = shared_storage.after_mm0.si;
auto& mi = shared_storage.mi;
auto& out_rescale = shared_storage.out_rescale;
ProblemVisitor problem_visitor(
params.problem_visitor,
@ -512,6 +519,7 @@ public:
if (thread_id() < kQueriesPerBlock) {
s_prime[thread_id()] = ElementAccumulator(0);
out_rescale[thread_id()] = accum_t(1.0);
m_prime[thread_id()] =
-cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
mi[thread_id()] = -cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
@ -568,7 +576,7 @@ public:
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
MM1::Mma::prologue(
shared_storage.after_mm0.mm1.mm,
shared_storage.after_mm0.mm1,
iterator_V,
thread_id(),
problem_size_1_k);
@ -623,6 +631,8 @@ public:
if (kPreloadV) {
prologueV(0);
} else {
MM1::Mma::drain_cp_asyncs();
}
typename MM0::Mma::Operator::IteratorC::TensorCoord
@ -649,30 +659,48 @@ public:
},
[&](int accum_m) {});
}
DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
num_keys - iter_key_start >= kKeysPerBlock,
kFullColumns,
([&] {
// Update `mi` from accum stored in registers
// Also does accum[i] <- exp(accum[i] - mi)
iterative_softmax<
typename MM0::Mma::Operator::IteratorC,
kFullColumns,
kIsFirst>(
accum_o,
accum,
mi,
m_prime,
s_prime,
lane_id(),
thread_id(),
warp_id(),
num_keys - iter_key_start,
iteratorC_tile_offset,
kSupportsBias ? 1.0f : params.scale);
}));
}));
// DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
// DISPATCH_BOOL(
// num_keys - iter_key_start >= kKeysPerBlock,
// kFullColumns,
// ([&] {
// // Update `mi` from accum stored in registers
// // Also does accum[i] <- exp(accum[i] - mi)
// iterative_softmax<
// typename MM0::Mma::Operator::IteratorC,
// kFullColumns,
// kIsFirst>(
// accum_o,
// accum,
// mi,
// m_prime,
// s_prime,
// lane_id(),
// thread_id(),
// warp_id(),
// num_keys - iter_key_start,
// iteratorC_tile_offset,
// kSupportsBias ? 1.0f : params.scale);
// }));
// }));
// Update `mi` from accum stored in registers
// Also does accum[i] <- exp(accum[i] - mi)
iterative_softmax<typename MM0::Mma::Operator::IteratorC>(
accum_o,
accum,
mi,
m_prime,
s_prime,
out_rescale,
shared_storage.addition_storage,
lane_id(),
thread_id(),
warp_id(),
num_keys - iter_key_start,
iter_key_start == 0,
iteratorC_tile_offset,
kSupportsBias ? 1.0f : params.scale);
// Output results to shared-memory
int warp_idx_mn_0 = warp_id() %
@ -717,12 +745,14 @@ public:
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
typename MM1::Mma mma_pv(
shared_storage.after_mm0.mm1.mm,
shared_storage.after_mm0.si,
// operand A: Pij_dropped in shared memory
shared_storage.after_mm0.si.accum_ref(),
// operand B: shared memory staging area for Vj, which is loaded
// from global memory
shared_storage.after_mm0.mm1.operand_B_ref(),
(int)thread_id(),
(int)warp_id(),
(int)lane_id(),
(int)problem_size_1_k);
(int)lane_id());
mma_pv.set_prologue_done(kPreloadV);
if (!kKeepOutputInRF) {
@ -737,6 +767,7 @@ public:
}
if (!kKeepOutputInRF) {
MM1::Mma::drain_cp_asyncs();
DISPATCH_BOOL(
iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
@ -787,7 +818,7 @@ public:
decltype(createOutputIter),
decltype(createOutputAccumIter)>::
apply(createOutputIter, createOutputAccumIter, col);
EpilogueOutputOp rescale(s_prime, m_prime);
EpilogueOutputOp rescale(s_prime, out_rescale);
Epilogue epilogue(
shared_storage.epilogue_shared_storage(),
thread_id(),
@ -836,34 +867,37 @@ public:
typename MM1::OutputTileIteratorAccum // source tile
>;
auto dest_iter = createOutputIter(0);
EpilogueOutputOp rescale(s_prime, m_prime);
EpilogueOutputOp rescale(s_prime, out_rescale);
Epilogue epilogue(
shared_storage.epilogue_shared_storage(),
thread_id(),
warp_id(),
lane_id());
MM1::Mma::drain_cp_asyncs();
epilogue(rescale, dest_iter, accum_o);
}
// Next tile
problem_visitor.advance(gridDim.x);
__syncthreads(); // Don't start the next iteration until all threads are done using shared memory.
}
}
template <
typename WarpIteratorC,
bool kFullColumns,
bool kIsFirst>
template <typename WarpIteratorC>
CUTLASS_DEVICE static void iterative_softmax(
typename WarpIteratorC::Fragment& frag_o, // output so far
typename WarpIteratorC::Fragment& frag,
cutlass::Array<accum_t, kQueriesPerBlock>& mi,
cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
cutlass::Array<accum_t, kQueriesPerBlock>& out_rescale,
cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>&
addition_storage,
int8_t lane_id,
int8_t thread_id,
int8_t warp_id,
int16_t max_col,
int max_col,
bool is_first,
typename WarpIteratorC::TensorCoord const& tile_offset,
float scaling) {
/* Iterates on the accumulator and corresponding position on result matrix
@ -884,12 +918,11 @@ public:
kThreadsPerWarp>::Iterator;
// Convert to `accum_t` (rather than double)
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
if (!kIsFirst) {
if (thread_id < kQueriesPerBlock) {
m_prime[thread_id] = mi[thread_id];
}
__syncthreads();
}
static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, "");
static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock;
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
auto lane_offset =
LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
@ -903,46 +936,64 @@ public:
max = -cutlass::platform::numeric_limits<accum_t>::infinity();
},
[&](int accum_m, int accum_n, int idx) {
if (kFullColumns || accum_n < max_col) {
if (accum_n < max_col) {
max = cutlass::fast_max(max, frag[idx]);
}
},
[&](int accum_m) {
// Having 4x atomicMax seems faster than reduce within warp
// first...
atomicMaxFloat(&mi[accum_m], max * scaling);
atomicMaxFloat(&mi[accum_m], max);
});
}
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
// Make sure we all share the update values for `mi`
__syncthreads();
if (thread_id < kQueriesPerBlock) {
auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
m_prime[thread_id] = m_prime_exp;
s_prime[thread_id] *= m_prime_exp;
// Doing this `exp` is quite expensive. Let's
// split it across the warps
bool restore_mi_to_minus_inf = false;
if (lane_id < kLinesPerWarp) {
int id = warp_id * kLinesPerWarp + lane_id;
auto m_prime_id = m_prime[id];
auto mi_id = mi[id];
bool changed = m_prime_id < mi_id; // `false` if both are -inf
if (changed) {
auto m_prime_exp = exp2f(m_prime_id - mi_id);
out_rescale[id] = m_prime_exp;
s_prime[id] *= m_prime_exp;
} else {
// Only when bias is enabled, it's possible that all the first values
// of attention are masked to `-inf`. In that case we want to avoid
// `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0
if (kSupportsBias &&
mi_id == -cutlass::platform::numeric_limits<accum_t>::infinity()) {
restore_mi_to_minus_inf = true;
mi[id] = 0.0f;
}
out_rescale[id] = 1.0f;
}
}
__syncthreads(); // Update output fragments
if (kKeepOutputInRF && !kIsFirst) {
accum_t mp;
if (kKeepOutputInRF && !is_first) {
accum_t line_rescale;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { mp = m_prime[accum_m]; },
[&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
[&](int accum_m) { line_rescale = out_rescale[accum_m]; },
[&](int accum_m, int accum_n, int idx) {
frag_o[idx] = frag_o[idx] * line_rescale;
},
[&](int accum_m) {});
__syncthreads();
}
// Update accum_m, accum_n, ...
{
accum_t mi_row, total_row;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
[&](int accum_m) { mi_row = mi[accum_m]; },
[&](int accum_m, int accum_n, int idx) {
frag[idx] = (kFullColumns || accum_n < max_col)
? exp2f(frag[idx] - mi_row)
: accum_t(0.0);
frag[idx] =
(accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
},
[&](int accum_m) {});
LambdaIterator::iterateRows(
@ -954,10 +1005,31 @@ public:
lane_id, total_row, [](accum_t a, accum_t b) {
return a + b;
})) {
atomicAdd(&s_prime[accum_m], total_row);
// NOTE: we could atomically add `total_row` to `s_prime`, but
// it's faster (and deterministic) to avoid atomics here
addition_storage
[accum_m + kQueriesPerBlock * tile_offset.column()] =
total_row;
}
});
}
__syncthreads();
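// Each warp has stored its partial row sums in `addition_storage` (indexed by
// row and warp column along N); below, one lane per row folds those partials
// into `s_prime`, giving a deterministic reduction without atomics (see the
// NOTE above).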
if (lane_id < kLinesPerWarp) {
int id = warp_id * kLinesPerWarp + lane_id;
accum_t total_row = s_prime[id];
if (restore_mi_to_minus_inf) {
// Restore `mi`, see above when we set `restore_mi_to_minus_inf=true`
mi[id] = -cutlass::platform::numeric_limits<accum_t>::infinity();
} else {
m_prime[id] = mi[id];
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
total_row += addition_storage[id + kQueriesPerBlock * i];
}
s_prime[id] = total_row;
}
}
};

View File

@ -0,0 +1,298 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/////////////////////////////////////////////////////////////////////////////////////////////////
#include <vector>
#include <iostream>
#include <fstream>
#include "kernel_backward.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/host_tensor.h"
using Arch = cutlass::arch::Sm80;
static constexpr int kMaxK = 128;
template <typename ArchTag, typename Element, int kMaxK>
struct DefaultKernel {
// Some heuristics to select the best kernel (tested on Sm60, Sm70, Sm80)
// NOTE: Requires quite a lot of shmem for Sm80+,
// so might require tweaking those manually for Sm86/Sm89
static constexpr bool kSupports64x128 =
ArchTag::kMinComputeCapability >= 80 ||
(ArchTag::kMinComputeCapability >= 70 &&
cutlass::sizeof_bits<Element>::value <= 16);
static constexpr int kBlockSizeI = kSupports64x128 && kMaxK > 64 ? 128 : 64;
static constexpr bool kIsHalf = cutlass::sizeof_bits<Element>::value <= 16;
static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI;
static constexpr bool kPreload = kIsHalf && ArchTag::kMinComputeCapability >= 80 && kOutputInRF;
static constexpr int kBlockSizeJ = kPreload && kMaxK > 64 ? 128 : 64;
using Kernel = AttentionBackwardKernel<
Arch,
Element,
true, // kIsAligned_
false, // kApplyDropout_
kPreload, // kPreload_
kBlockSizeI, // kBlockSizeI_,
kBlockSizeJ, // kBlockSizeJ_,
kMaxK, // kMaxK
false, // kKeysQueriesAlignedToBlockSize
true // kEnableSplitKeys
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace {
template <typename T> struct TypeName;
template <> struct TypeName<float> { static constexpr const char* Name = "f32"; };
template <> struct TypeName<cutlass::half_t> { static constexpr const char* Name = "f16"; };
template <> struct TypeName<cutlass::bfloat16_t> { static constexpr const char* Name = "b16"; };
void readExpect(std::string const& expected) {
std::string read;
std::cin >> read;
if (read != expected) {
std::cerr << "FATAL: Read '" << read << "' but expected '" << expected << "'" << std::endl;
std::exit(1);
}
}
/// Helpers to read from stdin
template <typename Element>
cutlass::HostTensor<Element, cutlass::layout::RowMajor> readTensorOnDevice(std::string const& expectedName) {
readExpect("tensor_begin");
readExpect(std::string(TypeName<Element>::Name) + ":" + expectedName);
uint64_t len = 0;
std::cin >> len;
readExpect("file");
std::string filename;
std::cin >> filename;
cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor({int64_t(1), int64_t(len / sizeof(Element))});
uint8_t* data = (uint8_t*)tensor.host_data();
std::fstream myFile(filename, std::ios::in | std::ios::binary );
myFile.read((char*)data, len);
readExpect("tensor_end");
tensor.sync_device();
return tensor;
}
int64_t readInt64(std::string const& expectedName) {
readExpect(expectedName);
int64_t s = 0;
std::cin >> s;
return s;
}
float readFloat(std::string const& expectedName) {
readExpect(expectedName);
float s = 0;
std::cin >> s;
return s;
}
// Writing
template <typename Element>
void writeTensor(std::string const& name, cutlass::HostTensor<Element, cutlass::layout::RowMajor>& tensor) {
tensor.sync_host(); // device->host
size_t u8len = tensor.size() * sizeof(Element);
// Python is expected to provide a file name to write to
readExpect("tmpfile");
std::string tmpfile;
std::cin >> tmpfile;
uint8_t* data = (uint8_t*)tensor.host_data();
std::fstream myFile(tmpfile, std::ios::out | std::ios::binary );
myFile.write((char*)data, u8len);
myFile.close();
std::cout << "tensor_begin " << TypeName<Element>::Name << ":" << name << " ";
std::cout << u8len << " file " << tmpfile << " tensor_end" << std::endl;
}
void writeInt64(std::string const& name, int64_t value) {
std::cout << name << " " << value << std::endl;
}
}
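// The helpers above implement the plain-text handshake driven by
// fmha_backward_test.py: tensors are announced as
//   tensor_begin <dtype>:<name> <bytes> file <path> ... tensor_end
// and exchanged through temporary files, while scalars travel as
// "<name> <value>" pairs; readExpect() aborts on any mismatch so both
// sides stay in sync.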
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Element>
int runKernel() {
using Kernel = typename DefaultKernel<Arch, Element, kMaxK>::Kernel;
#define READ_I64(NAME) p.NAME = (decltype(p.NAME))readInt64(#NAME)
#define READ_TENSOR_AND_STRIDES_BMH(DT, NAME, NAME_XS) \
auto storage##NAME = readTensorOnDevice<DT>(#NAME); \
p.NAME##_ptr = storage##NAME.device_data(); \
READ_I64(NAME_XS##_strideB); \
READ_I64(NAME_XS##_strideM); \
READ_I64(NAME_XS##_strideH);
#define CUDA_CHECK(FN) { \
auto cudaError = FN; \
if (cudaError != cudaSuccess) { \
std::cerr << "FATAL: " #FN " failed: " << cudaGetErrorString(cudaError) << std::endl; \
return -1; \
} \
}
typename Kernel::Params p;
p.scale = readFloat("scale");
READ_I64(head_dim);
READ_I64(head_dim_value);
READ_I64(num_queries);
READ_I64(num_keys);
READ_I64(num_heads);
READ_I64(custom_mask_type);
READ_I64(num_batches);
int64_t repeat_count = readInt64("repeat_count");
READ_I64(num_splits_key);
READ_TENSOR_AND_STRIDES_BMH(Element, query, q);
READ_TENSOR_AND_STRIDES_BMH(Element, key, k);
READ_TENSOR_AND_STRIDES_BMH(Element, value, v);
auto lse = readTensorOnDevice<typename Kernel::lse_scalar_t>("logsumexp");
p.logsumexp_ptr = lse.device_data();
p.lse_strideB = readInt64("lse_strideB");
p.lse_strideH = readInt64("lse_strideH");
// output
auto stOutput = readTensorOnDevice<Element>("output");
p.output_ptr = stOutput.device_data();
READ_I64(o_strideB);
auto o_strideM = readInt64("o_strideM");
if (o_strideM != p.o_strideM()) {
std::cerr << "Invalid `o_strideM`: " << o_strideM << " - expected " << p.o_strideM();
return 2;
}
READ_I64(o_strideH);
READ_TENSOR_AND_STRIDES_BMH(Element, grad_output, gO);
auto stDelta = readTensorOnDevice<typename Kernel::accum_t>("delta");
p.delta_ptr = stDelta.device_data();
READ_I64(delta_strideB);
READ_I64(delta_strideH);
// Allocate workspace
if (p.workspace_size()) {
cudaMalloc(&p.workspace, p.workspace_size());
}
// Allocate outputs in BMHK format
p.gQKV_strideM_multiplier = 1;
p.gQ_strideH = p.head_dim;
p.gQ_strideB = p.gQ_strideM() * p.num_queries;
p.gK_strideH = p.head_dim;
p.gK_strideB = p.gK_strideM() * p.num_keys;
p.gV_strideH = p.head_dim_value;
p.gV_strideB = p.gV_strideM() * p.num_keys;
cutlass::HostTensor<Element, cutlass::layout::RowMajor> gQ({int64_t(1), p.gQ_strideB * p.num_batches});
cutlass::HostTensor<Element, cutlass::layout::RowMajor> gK({int64_t(1), p.gK_strideB * p.num_batches});
cutlass::HostTensor<Element, cutlass::layout::RowMajor> gV({int64_t(1), p.gV_strideB * p.num_batches});
p.grad_query_ptr = gQ.device_data();
p.grad_key_ptr = gK.device_data();
p.grad_value_ptr = gV.device_data();
if (!Kernel::check_supported(p)) {
std::cerr << "FATAL: Kernel does not support these inputs" << std::endl;
return 2;
}
// Run kernel
cudaDeviceSynchronize();
auto kernel_fn = attention_kernel_backward_batched_impl<Kernel>;
size_t smem_bytes = sizeof(typename Kernel::SharedStorage);
CUDA_CHECK(cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, int(smem_bytes)));
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
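// This first launch produces the gradients that are written back below; the
// timed loop further down re-runs the same kernel `repeat_count` times and the
// reported runtime_ms is the average over those repetitions.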
// Write outputs
std::cout << "OK ";
writeTensor("grad_query", gQ);
writeInt64("gQ_strideB", p.gQ_strideB);
writeInt64("gQ_strideM", p.gQ_strideM());
writeInt64("gQ_strideH", p.gQ_strideH);
writeTensor("grad_key", gK);
writeInt64("gK_strideB", p.gK_strideB);
writeInt64("gK_strideM", p.gK_strideM());
writeInt64("gK_strideH", p.gK_strideH);
writeTensor("grad_value", gV);
writeInt64("gV_strideB", p.gV_strideB);
writeInt64("gV_strideM", p.gV_strideM());
writeInt64("gV_strideH", p.gV_strideH);
// Timing
cudaEvent_t events[2];
for (auto & event : events) {
CUDA_CHECK(cudaEventCreate(&event));
}
CUDA_CHECK(cudaEventRecord(events[0]));
for (int i = 0; i < repeat_count; ++i) {
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
}
CUDA_CHECK(cudaEventRecord(events[1]));
CUDA_CHECK(cudaEventSynchronize(events[1]));
// Measure elapsed runtime
float runtime_ms = 0;
CUDA_CHECK(cudaEventElapsedTime(&runtime_ms, events[0], events[1]));
std::cout << "runtime_ms " << runtime_ms / float(repeat_count) << std::endl;
return 0;
}
int main() {
std::ios_base::sync_with_stdio(false);
std::string dtype;
std::cin >> dtype;
std::cerr << "Running kernel with dtype: " << dtype << std::endl;
if (dtype == "f16") {
return runKernel<cutlass::half_t>();
} else if (dtype == "b16") {
return runKernel<cutlass::bfloat16_t>();
} else if (dtype == "f32") {
return runKernel<float>();
} else {
std::cerr << "FATAL: Unknown dtype: " << dtype << std::endl;
return 3;
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -999,7 +999,7 @@ public:
template <
int kQueriesPerBlock,
int kKeysPerBlock,
bool kSingleValueIteration
int kMaxK
>
int run_attention(Options& options) {
using Attention = AttentionKernel<
@ -1008,7 +1008,7 @@ int run_attention(Options& options) {
true, // Memory is aligned
kQueriesPerBlock,
kKeysPerBlock,
kSingleValueIteration,
kMaxK,
false, // Supports dropout
false // Supports bias
>;
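// kMaxK is a compile-time upper bound on head_size_v; whenever
// kMaxK <= kKeysPerBlock a single iteration covers the whole value dimension
// (cf. DefaultFMHAGrouped earlier in this diff, which derives
// kSingleValueIteration = kMaxK <= kKeysPerBlock), and main() below dispatches
// with kMaxK = 128 or a large 65536 sentinel.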
@ -1094,15 +1094,16 @@ int main(int argc, char const **args) {
if (options.head_size_v > 64) {
static int const kQueriesPerBlock = 32;
static int const kKeysPerBlock = 128;
if (options.head_size_v <= kKeysPerBlock) {
return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
if (options.head_size_v <= 128) {
return run_attention<kQueriesPerBlock, kKeysPerBlock, 128>(options);
} else {
return run_attention<kQueriesPerBlock, kKeysPerBlock, false>(options);
return run_attention<kQueriesPerBlock, kKeysPerBlock, 65536>(options);
}
} else {
static constexpr int kMaxK = 64; // <- Decrease to 32/16 if your problem is smaller
static int const kQueriesPerBlock = 64;
static int const kKeysPerBlock = 64;
return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
return run_attention<kQueriesPerBlock, kKeysPerBlock, kMaxK>(options);
}
}

View File

@ -1061,7 +1061,7 @@ public:
template <
int kQueriesPerBlock,
int kKeysPerBlock,
bool kSingleValueIteration,
int kMaxK,
cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_
>
int run_grouped(Options& options) {
@ -1071,7 +1071,7 @@ int run_grouped(Options& options) {
true, // Memory is aligned
kQueriesPerBlock,
kKeysPerBlock,
kSingleValueIteration,
kMaxK,
GroupScheduleMode_
>::FMHAKernel;
@ -1098,18 +1098,18 @@ int run_grouped(Options& options) {
template <
int kQueriesPerBlock,
int kKeysPerBlock,
bool kSingleValueIteration
int kMaxK
>
int run_attention(Options& options) {
if (options.scheduler_mode == cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly) {
return run_grouped<kQueriesPerBlock,
kKeysPerBlock,
kSingleValueIteration,
kMaxK,
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly>(options);
} else {
return run_grouped<kQueriesPerBlock,
kKeysPerBlock,
kSingleValueIteration,
kMaxK,
cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>(options);
}
}
@ -1180,14 +1180,15 @@ int main(int argc, char const **args) {
static int const kQueriesPerBlock = 32;
static int const kKeysPerBlock = 128;
if (options.head_size_v <= kKeysPerBlock) {
return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
return run_attention<kQueriesPerBlock, kKeysPerBlock, 128>(options);
} else {
return run_attention<kQueriesPerBlock, kKeysPerBlock, false>(options);
return run_attention<kQueriesPerBlock, kKeysPerBlock, 65536>(options);
}
} else {
static constexpr int kMaxK = 64; // <- Decrease to 32/16 if your problem is smaller
static int const kQueriesPerBlock = 64;
static int const kKeysPerBlock = 64;
return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
return run_attention<kQueriesPerBlock, kKeysPerBlock, kMaxK>(options);
}
}

View File

@ -747,14 +747,6 @@ class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
arch::OpMultiplyAddComplexFastF32>::value) {
accum = plus_accum(accum, tmp_accum);
}
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
// commit and drain all pending and predicated cp.async pnz from the GEMM
// mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
}
};

View File

@ -310,7 +310,8 @@ class CustomMmaPipelined : public CustomMmaBase<Shape_, Policy_, 2> {
iterator_B.clear_mask(gemm_k_iterations <= 1);
// Issue loads during the first warp-level matrix multiply-add *AFTER*
// issuing shared memory loads (which have the tightest latency requirement).
// issuing shared memory loads (which have the tightest latency
// requirement).
//
// Mainloop

View File

@ -30,7 +30,8 @@
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
\brief Tools and utils to store a GEMM output in shmem, and to use that
output as operandA for another GEMM back-to-back
*/
#pragma once
@ -55,6 +56,7 @@
#include "../epilogue/epilogue_thread_apply_logsumexp.h"
#include "../gemm/mma_accum_lambda_iterator.h"
#include "../gemm_kernel_utils.h"
#include "../iterators/default_warp_iterator_from_smem.h"
#include "../iterators/make_residual_last.h"
#include "../iterators/transpose_warp_iterator.h"
#include "../iterators/warp_iterator_from_smem.h"
@ -128,18 +130,22 @@ class AccumulatorSharedStorage {
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
// Maximum value for K
int kMaxK,
// Maximum K dimension - also the K extent of the shared-memory tile
// holding `OperandA`
int kMaxK_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Layout in shared-memory of operand A
typename SmemLayoutA,
/// Used for partial specialization
typename Enable = bool>
class MmaBaseFromSharedMemory {
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
static constexpr int kMaxK = kMaxK_;
///< Policy describing tuning details
using Policy = Policy_;
@ -175,8 +181,7 @@ class MmaBaseFromSharedMemory {
static bool const kSmemContainsEntireB = kMaxK <= Shape::kK * kStages;
/// Tensor reference to the A operand
using TensorRefA =
TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
using TensorRefA = TensorRef<typename Operator::ElementA, SmemLayoutA>;
/// Tensor reference to the B operand
using TensorRefB =
@ -240,14 +245,14 @@ class MmaBaseFromSharedMemory {
CUTLASS_DEVICE
MmaBaseFromSharedMemory(
///< Shared storage needed for internal use by threadblock-scoped GEMM
SharedStorage& shared_storage,
TensorRefB& b_tile,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
: warp_tile_iterator_B_(b_tile, lane_idx) {}
};
namespace {
@ -264,9 +269,8 @@ class NoOpWarpIteratorScale {
// in pipelined+multistage MMA implementations we keep an array of fragments.
// if we aren't using scaling we don't want to waste registers on fragments
// of scale elements, so ideally this would be sized 0.
// using size 1 is kind of a hack to get around arrays of zero-sized objects
// not being allowed. the compiler is probably smart enough to wipe it out
// anyways.
// Since arrays of zero-sized objects are not allowed, using size as 1.
// The compiler will most likely wipe it out anyways.
using Fragment = cutlass::Array<char, 1>;
CUTLASS_HOST_DEVICE
@ -334,14 +338,13 @@ template <
typename Shape_,
// BEGIN smem
/// Iterates over the intermediate accumulator tile in shared memory
typename WarpIteratorA,
typename WarpIteratorA_,
/// whether or not to perform elementwise multiplication of A
// by another matrix (A_scale) that is also kept in shared memory prior
// to matmul A @ B
bool ScaleOperandA_,
// Accumulator type
typename AccumulatorSharedStorage,
// END smem
/// Max GEMM problem size in K dimension
int MaxK,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
@ -364,21 +367,24 @@ template <
typename Enable = bool>
class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
Shape_,
AccumulatorSharedStorage::Shape::kN,
MaxK,
Policy_,
2> {
2,
typename WarpIteratorA_::Layout> {
public:
///< Base class
using Base = MmaBaseFromSharedMemory<
Shape_,
AccumulatorSharedStorage::Shape::kN,
MaxK,
Policy_,
2>;
2,
typename WarpIteratorA_::Layout>;
using Shape =
Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
static constexpr bool ScaleOperandA = ScaleOperandA_;
using WarpIteratorA = WarpIteratorA_;
///< loads fragments of A_scale from shared memory if operand A scaling is
///< enabled. otherwise no-op.
using WarpIteratorAScale = typename cutlass::platform::conditional<
@ -455,19 +461,17 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
/// constructor for MMA with operand A scaling enabled.
CUTLASS_DEVICE
MmaPipelinedFromSharedMemory(
// shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage& shared_storage,
// warp iterator over A tile held in shared memory
WarpIteratorA warp_iter_a,
// warp iterator over A_scale tile held in shared memory
WarpIteratorAScale warp_iter_a_scale,
typename Base::TensorRefA a, // Operand A in shared memory
typename Base::TensorRefA a_scale, // Operand A_scale in shared memory
typename Base::TensorRefB
b_staging, // staging memory for loading tiles of B
int thread_idx,
int warp_idx,
int lane_idx)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A_(warp_iter_a),
warp_tile_iterator_A_scale_(warp_iter_a_scale),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
: Base(b_staging, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A_(a, lane_idx),
warp_tile_iterator_A_scale_(a_scale, lane_idx),
smem_iterator_B_(b_staging, thread_idx) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
@ -490,17 +494,14 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
/// Construct from tensor references
CUTLASS_DEVICE
MmaPipelinedFromSharedMemory(
typename Base::SharedStorage&
shared_storage, ///< Shared storage needed for internal use by
///< threadblock-scoped GEMM
AccumulatorSharedStorage& accumulator_shared_storage,
typename Base::TensorRefA a, ///< Operand A in shared memory
typename Base::TensorRefB b_staging, ///< staging memory for loading B
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx, ///< ID of each thread within a warp
int problem_size_0_n)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A_(accumulator_shared_storage.accum_ref(), lane_idx),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
int lane_idx) ///< ID of each thread within a warp
: Base(b_staging, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A_(a, lane_idx),
smem_iterator_B_(b_staging, thread_idx) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
@ -532,6 +533,9 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
int thread_idx,
int problem_size_0_n) {}
CUTLASS_DEVICE
static void drain_cp_asyncs() {}
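// No-op here: the two-stage pipelined mainloop does not issue cp.async copies,
// so there is nothing to fence or wait on. The multistage variant below
// performs cp_async_fence() / cp_async_wait<0>() plus a __syncthreads() in its
// drain_cp_asyncs().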
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
@ -600,7 +604,8 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
iterator_B.clear_mask(gemm_k_iterations <= 1);
// Issue loads during the first warp-level matrix multiply-add *AFTER*
// issuing shared memory loads (which have the tightest latency requirement).
// issuing shared memory loads (which have the tightest latency
// requirement).
//
// Mainloop
@ -621,8 +626,10 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
bool hasNext = true;
if (warp_mma_k == Base::kWarpGemmIterations - 1) {
// Write fragments to shared memory
this->smem_iterator_B_.store(transform_B(tb_frag_B));
if (gemm_k_iterations > 1) {
// Write fragments to shared memory
this->smem_iterator_B_.store(transform_B(tb_frag_B));
}
__syncthreads();
@ -696,8 +703,6 @@ template <
// by another matrix (A_scale) that is also kept in shared memory prior
// to matmul A @ B
bool ScaleOperandA_,
// Accumulator type
typename AccumulatorSharedStorage,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
@ -718,11 +723,20 @@ template <
int kMaxK_,
/// Used for partial specialization
typename Enable = bool>
class MmaMultistageFromSharedMemory
: public MmaBaseFromSharedMemory<Shape1_, kMaxK_, Policy1_, Stages_> {
class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory<
Shape1_,
kMaxK_,
Policy1_,
Stages_,
typename WarpIteratorA1_::Layout> {
public:
///< Base class
using Base = MmaBaseFromSharedMemory<Shape1_, kMaxK_, Policy1_, Stages_>;
using Base = MmaBaseFromSharedMemory<
Shape1_,
kMaxK_,
Policy1_,
Stages_,
typename WarpIteratorA1_::Layout>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape1 = Shape1_;
@ -826,20 +840,16 @@ class MmaMultistageFromSharedMemory
/// constructor for MMA with operand A scaling enabled.
CUTLASS_DEVICE
MmaMultistageFromSharedMemory(
// shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage& shared_storage,
// warp level iterator over operand A tile kept in shared memory
WarpIteratorA1 warp_tile_iterator_A1,
// warp level iterator over operand A elementwise scale tile kept in
// shared memory.
WarpIteratorAScale warp_tile_iterator_A1_scale,
typename Base::TensorRefA a,
typename Base::TensorRefA a_scale,
typename Base::TensorRefB b_tile,
int thread_idx,
int warp_idx,
int lane_idx)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A1_(warp_tile_iterator_A1),
warp_tile_iterator_A1_scale_(warp_tile_iterator_A1_scale),
smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx),
: Base(b_tile, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A1_(a, lane_idx),
warp_tile_iterator_A1_scale_(a_scale, lane_idx),
smem_iterator_B1_(b_tile, thread_idx),
prologue_done_(false) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
@ -864,23 +874,17 @@ class MmaMultistageFromSharedMemory
/// Construct from tensor references
CUTLASS_DEVICE
MmaMultistageFromSharedMemory(
typename Base::SharedStorage&
shared_storage, ///< Shared storage needed for internal use by
///< threadblock-scoped GEMM
AccumulatorSharedStorage& accumulator_shared_storage,
typename Base::TensorRefA a,
typename Base::TensorRefB b_tile,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx,
///< GEMM0 N is used for accumulator extent
int problem_size_0_n)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A1_(
accumulator_shared_storage.accum_ref(),
lane_idx),
smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx),
int lane_idx)
: Base(b_tile, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A1_(a, lane_idx),
smem_iterator_B1_(b_tile, thread_idx),
prologue_done_(false) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
@ -920,6 +924,15 @@ class MmaMultistageFromSharedMemory
smem_iterator_B1);
}
CUTLASS_DEVICE
static void drain_cp_asyncs() {
// commit and drain all pending and predicated cp.async pnz from the GEMM
// mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
CUTLASS_DEVICE
void copy_tiles_and_advance_1(
IteratorB1& iterator_B1,
@ -1254,100 +1267,11 @@ class MmaMultistageFromSharedMemory
}
};
template <
typename WarpShape,
typename InstructionShape,
typename RegularWarpIterator,
typename Policy,
typename Enable = void>
struct DefaultWarpIteratorAFromSharedMemory {};
// TensorOp - Ampere half
template <typename RegularWarpIterator, typename Policy>
struct DefaultWarpIteratorAFromSharedMemory<
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
RegularWarpIterator,
Policy,
typename platform::enable_if<(
sizeof_bits<typename RegularWarpIterator::Element>::value == 16 &&
Policy::Operator::Policy::OpDelta::kRow == 1)>::type> {
static constexpr auto kWarpSize = 32;
using OpDelta = typename Policy::Operator::Policy::OpDelta;
using WarpShape = cutlass::MatrixShape<32, 32>;
using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem<
cutlass::gemm::Operand::kA,
typename RegularWarpIterator::Element>;
};
// TensorOp - Ampere f32
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
struct DefaultWarpIteratorAFromSharedMemory<
WarpShape,
cutlass::gemm::GemmShape<16, 8, 8>,
RegularWarpIterator,
Policy,
typename platform::enable_if<(
sizeof_bits<typename RegularWarpIterator::Element>::value != 16 ||
Policy::Operator::Policy::OpDelta::kRow != 1)>::type> {
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
static constexpr auto kWarpSize = 32;
using OpDelta = typename Policy::Operator::Policy::OpDelta;
using WarpIterator =
cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
cutlass::MatrixShape<WarpShape::kM, WarpShape::kK>,
cutlass::gemm::Operand::kA,
typename RegularWarpIterator::Element,
cutlass::layout::RowMajor,
cutlass::MatrixShape<InstructionShape::kM, InstructionShape::kK>,
OpDelta::kRow,
kWarpSize>;
};
// TensorOp - Volta
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
struct DefaultWarpIteratorAFromSharedMemory<
WarpShape,
cutlass::gemm::GemmShape<16, 16, 4>,
RegularWarpIterator,
Policy> {
using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>;
static constexpr auto kWarpSize = 32;
using OpDelta = typename Policy::Operator::Policy::OpDelta;
using WarpIterator =
cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator<
cutlass::MatrixShape<32, 32>, // MatrixShape<WarpShape::kM,
// WarpShape::kK>,
cutlass::gemm::Operand::kA,
typename RegularWarpIterator::Element,
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>,
cutlass::MatrixShape<16, 4>,
OpDelta::kRow,
kWarpSize>;
};
// Simt
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
struct DefaultWarpIteratorAFromSharedMemory<
WarpShape,
cutlass::gemm::GemmShape<1, 1, 1>,
RegularWarpIterator,
Policy> {
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
static constexpr auto kWarpSize = 32;
// We just use the same iterator, as we reproduced the same shared-memory
// schema. Just modify it to handle non-complete tiles.
using WarpIterator = RegularWarpIterator;
};
// Converts a "regular" Mma into their counterpart from shared memory
template <
typename Mma_,
typename AccumulatorSharedStorage,
int kMaxK,
typename WarpIteratorA_,
/// whether or not to apply elementwise multiplication of operand A by
/// another matrix in shared memory before usage in A @ B
bool kScaleOperandA,
@ -1365,6 +1289,7 @@ template <
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
typename WarpIteratorA_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
@ -1382,7 +1307,8 @@ template <
typename TransformA_,
/// Transformation applied to B operand
typename TransformB_,
typename AccumulatorSharedStorage_,
// Max MMA problem size K
int kMaxK,
/// whether or not to apply elementwise multiplication of operand A by
/// another matrix in shared memory before usage in A @ B
bool kScaleOperandA,
@ -1399,12 +1325,10 @@ struct DefaultMmaFromSharedMemory<
Policy_,
TransformA_,
TransformB_>,
AccumulatorSharedStorage_,
kMaxK,
WarpIteratorA_,
kScaleOperandA,
kTransposeA> {
static constexpr int kWarpSize = 32;
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
using RegularMma = MmaPipelined<
Shape_,
IteratorA_,
@ -1422,11 +1346,7 @@ struct DefaultMmaFromSharedMemory<
using ArchMmaOperator = typename Policy_::Operator;
static constexpr bool kIsTransposedA = false;
using WarpIteratorA = typename DefaultWarpIteratorAFromSharedMemory<
WarpShape,
InstructionShape,
typename RegularMma::Operator::IteratorA,
Policy_>::WarpIterator;
using WarpIteratorA = WarpIteratorA_;
using IteratorB =
typename cutlass::transform::threadblock::MakeIteratorResidualLast<
IteratorB_>::Iterator;
@ -1435,7 +1355,7 @@ struct DefaultMmaFromSharedMemory<
Shape_,
WarpIteratorA,
kScaleOperandA,
AccumulatorSharedStorage_,
kMaxK,
IteratorB,
SmemIteratorB_,
ElementC_,
@ -1453,6 +1373,7 @@ template <
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
typename WarpIteratorA_,
/// Cache operation for operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Iterates over tiles of B operand in global memory
@ -1474,7 +1395,7 @@ template <
int Stages,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
typename AccumulatorSharedStorage_,
int kMaxK,
/// whether or not to apply elementwise multiplication of operand A by
/// another matrix in shared memory before usage in A @ B
bool kScaleOperandA,
@ -1493,11 +1414,10 @@ struct DefaultMmaFromSharedMemory<
Policy_,
Stages,
SharedMemoryClear>,
AccumulatorSharedStorage_,
kMaxK,
WarpIteratorA_,
kScaleOperandA,
kTransposeA> {
static constexpr int kWarpSize = 32;
using RegularMma = MmaMultistage<
Shape_,
IteratorA_,
@ -1514,11 +1434,6 @@ struct DefaultMmaFromSharedMemory<
using WarpShape = typename Policy_::Operator::Shape;
using InstructionShape = typename Policy_::Operator::InstructionShape;
using WarpIteratorA_ = typename DefaultWarpIteratorAFromSharedMemory<
WarpShape,
InstructionShape,
typename RegularMma::Operator::IteratorA,
Policy_>::WarpIterator;
using WarpIteratorTranspose = TransposeWarpIterator<WarpIteratorA_>;
static constexpr bool kIsTransposedA =
WarpIteratorTranspose::kSupportsTranspose && kTransposeA;
@ -1527,9 +1442,6 @@ struct DefaultMmaFromSharedMemory<
typename WarpIteratorTranspose::Iterator,
WarpIteratorA_>::type;
static int constexpr kMaxK = kIsTransposedA
? AccumulatorSharedStorage_::Shape::kM
: AccumulatorSharedStorage_::Shape::kN;
// Reduce the number of stages if we don't need that many
static int constexpr kStagesMax =
(kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK);
@ -1543,7 +1455,6 @@ struct DefaultMmaFromSharedMemory<
Shape_,
WarpIteratorA,
kScaleOperandA,
AccumulatorSharedStorage_,
IteratorB,
SmemIteratorB_,
RegularMma::kCacheOpB,
@ -1751,27 +1662,17 @@ struct B2bGemm<
using FragmentC = IteratorC::Fragment;
using lse_scalar_t = float;
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
using SmemIteratorD0 = cutlass::epilogue::warp::TileIteratorVoltaTensorOp<
WarpShape,
cutlass::gemm::GemmShape<32, 32, 4>,
scalar_t,
SmemAccumulatorLayout>;
// // Storage in shared-memory for Q.Kt
// Storage in shared-memory for Q.Kt
using SmemAccumulatorLayout =
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>;
using AccumulatorSharedStorage =
cutlass::gemm::threadblock::AccumulatorSharedStorage<
ThreadblockShape,
scalar_t,
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<
16,
32>, // typename SmemIteratorD0::TensorLayout,
SmemAccumulatorLayout,
cutlass::MatrixShape<0, 0> // Padding
>;
using OutputLayout =
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>;
using TensorRef = cutlass::TensorRef<scalar_t, OutputLayout>;
using TensorRef = cutlass::TensorRef<scalar_t, SmemAccumulatorLayout>;
using Policy = typename IteratorC::Policy;
using Element = accum_t;
// Those are MmaVoltaTensorOpAccumulatorTileIterator private fields


@ -115,10 +115,10 @@
std::cerr << #PTR " is not correctly aligned\n"; \
return false; \
}
#define XFORMERS_CHECK(COND, ERR) \
if (!(COND)) { \
std::cerr << #COND " failed\n"; \
return false; \
#define XFORMERS_CHECK(COND, ERR) \
if (!(COND)) { \
std::cerr << "'" #COND "' failed: " << ERR << "\n"; \
return false; \
}
#endif
@ -228,8 +228,17 @@ struct call_conditional<false, TA, TB> {
// The cheapest way to do it is just to broadcast it from lane 0
////////////////////////////////////////////////////////////////////////////////
CUTLASS_DEVICE int32_t warp_uniform(int32_t value) {
return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0);
template <typename T>
CUTLASS_DEVICE T warp_uniform(T value) {
struct {
union {
T value;
uint32_t asInt;
};
} p;
p.value = value;
p.asInt = __shfl_sync(0xffffffff, (unsigned)p.asInt, 0);
return p.value;
}
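// Illustrative sketch (not part of this patch): the same lane-0 broadcast
// spelled out for a plain 32-bit value. `broadcast_from_lane0` is a made-up
// helper name; the point is that after __shfl_sync from lane 0 every lane of
// the warp holds the same value, so the compiler can treat it (and anything
// derived from it) as warp-uniform.
CUTLASS_DEVICE int broadcast_from_lane0(int value) {
  // 0xffffffff: all 32 lanes participate; 0: the source lane.
  return __shfl_sync(0xffffffff, value, 0);
}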
template <typename T>


@ -0,0 +1,143 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Instantiates the right WarpIterator to read from shared memory
The class `DefaultWarpIteratorAFromSharedMemory` is useful when reading
data dumped with `B2bGemm::accumToSmem`.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h"
#include "cutlass/platform/platform.h"
#include "warp_iterator_from_smem.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
template <
typename WarpShape,
typename InstructionShape,
typename RegularWarpIterator,
typename Policy,
typename Enable = void>
struct DefaultWarpIteratorAFromSharedMemory {};
// TensorOp - Ampere half
template <typename RegularWarpIterator, typename Policy, int kInstrK>
struct DefaultWarpIteratorAFromSharedMemory<
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, kInstrK>,
RegularWarpIterator,
Policy,
typename platform::enable_if<(
sizeof_bits<typename RegularWarpIterator::Element>::value == 16 &&
Policy::Operator::Policy::OpDelta::kRow == 1)>::type> {
using OpDelta = typename Policy::Operator::Policy::OpDelta;
using WarpShape = cutlass::MatrixShape<32, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, kInstrK>;
using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem<
cutlass::gemm::Operand::kA,
typename RegularWarpIterator::Element,
cutlass::MatrixShape<InstructionShape::kM, InstructionShape::kK>>;
};
// TensorOp - Ampere f32
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
struct DefaultWarpIteratorAFromSharedMemory<
WarpShape,
cutlass::gemm::GemmShape<16, 8, 8>,
RegularWarpIterator,
Policy,
typename platform::enable_if<(
sizeof_bits<typename RegularWarpIterator::Element>::value != 16 ||
Policy::Operator::Policy::OpDelta::kRow != 1)>::type> {
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
static constexpr auto kWarpSize = 32;
using OpDelta = typename Policy::Operator::Policy::OpDelta;
using WarpIterator =
cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
cutlass::MatrixShape<WarpShape::kM, WarpShape::kK>,
cutlass::gemm::Operand::kA,
typename RegularWarpIterator::Element,
cutlass::layout::RowMajor,
cutlass::MatrixShape<InstructionShape::kM, InstructionShape::kK>,
OpDelta::kRow,
kWarpSize>;
};
// TensorOp - Volta
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
struct DefaultWarpIteratorAFromSharedMemory<
WarpShape,
cutlass::gemm::GemmShape<16, 16, 4>,
RegularWarpIterator,
Policy> {
using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>;
static constexpr auto kWarpSize = 32;
using OpDelta = typename Policy::Operator::Policy::OpDelta;
using WarpIterator =
cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator<
cutlass::MatrixShape<32, 32>, // MatrixShape<WarpShape::kM,
// WarpShape::kK>,
cutlass::gemm::Operand::kA,
typename RegularWarpIterator::Element,
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>,
cutlass::MatrixShape<16, 4>,
OpDelta::kRow,
kWarpSize>;
};
// Simt
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
struct DefaultWarpIteratorAFromSharedMemory<
WarpShape,
cutlass::gemm::GemmShape<1, 1, 1>,
RegularWarpIterator,
Policy> {
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
static constexpr auto kWarpSize = 32;
// We just use the same iterator, as we reproduced the same shared-memory
// schema. Just modify it to handle non-complete tiles.
using WarpIterator = RegularWarpIterator;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass


@ -175,7 +175,7 @@ class PredicatedTileAccessIteratorResidualLast<
Mask residual_tile_mask;
/// Parameters object with precomputed internal state
Params const& params_;
Params params_;
/// Internal pointer to first access of tile
BytePointer pointer_;
@ -1018,7 +1018,7 @@ class PredicatedTileAccessIteratorResidualLast<
//
/// Parameters object with precomputed internal state
Params const& params_;
Params params_;
/// Internal pointer to first access of tile
BytePointer pointer_;


@ -44,10 +44,12 @@ template <
cutlass::gemm::Operand Operand,
/// Data type of A elements
typename Element,
typename InstructionShape,
bool kTranspose>
struct TransposeWarpIterator<
cutlass::gemm::warp::WarpIteratorFromSmem<Operand, Element, kTranspose>> {
using Iterator =
cutlass::gemm::warp::WarpIteratorFromSmem<Operand, Element, !kTranspose>;
cutlass::gemm::warp::
WarpIteratorFromSmem<Operand, Element, InstructionShape, kTranspose>> {
using Iterator = cutlass::gemm::warp::
WarpIteratorFromSmem<Operand, Element, InstructionShape, !kTranspose>;
static bool constexpr kSupportsTranspose = true;
};


@ -56,6 +56,7 @@ template <
Operand Operand_,
/// Data type of A elements
typename Element_,
typename InstructionShape_,
bool kTranspose = false>
class WarpIteratorFromSmem {
public:
@ -64,6 +65,9 @@ class WarpIteratorFromSmem {
/// Operand tag
static Operand const kOperand = Operand_;
static_assert(
kOperand == Operand::kA,
"No support for OperandB at the moment");
/// Basic check
static_assert(
@ -78,7 +82,11 @@ class WarpIteratorFromSmem {
using Layout = cutlass::layout::RowMajor;
/// Shape of one matrix product operation (concept: MatrixShape)
using InstructionShape = cutlass::MatrixShape<16, 8>;
using InstructionShape = InstructionShape_;
static_assert(InstructionShape::kRow == 16, "Only supports 16x8x8 / 16x8x16");
static_assert(
InstructionShape::kColumn == 8 || InstructionShape::kColumn == 16,
"Only supports 16x8x8 / 16x8x16");
/// Delta between *MMA operations (in units of *MMA operations, concept:
/// MatrixShape)
@ -133,7 +141,9 @@ class WarpIteratorFromSmem {
: InstructionShape::kRow);
static int constexpr kAccessesInner =
(kWarpShapeDivisibleInner / kElementsPerAccess) / 4;
// Number of 32-bit tiles to load per `ldmatrix`
static int const kTilesPerInstruction = InstructionShape::kRow / 8;
static_assert(kTilesPerInstruction == 2, "Only supports 16x8x16 and 16x8x8");
private:
/// Underlying tensor reference
@ -153,38 +163,28 @@ class WarpIteratorFromSmem {
CUTLASS_HOST_DEVICE
WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id)
: ref_(ref), iterations_(0) {
// See also:
// https://docs.nvidia.com/cuda/archive/11.7.1/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-1688
// 16x8x8: kAccessesInner = 1 (1 ldmatrix.x4)
// 16x8x16: kAccessesInner = 2 (2 ldmatrix.x4)
int ldsm_vec_num = (lane_id >> 3);
if (kOperand == Operand::kA) {
origin_ = MatrixCoord(lane_id % 8, 0);
static_assert(
InstructionCount::kRow * kAccessesInner * kTilesPerInstruction == 4,
"");
CUTLASS_PRAGMA_UNROLL
for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow;
++inst_m_idx) {
CUTLASS_PRAGMA_UNROLL
for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
CUTLASS_PRAGMA_UNROLL
for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction;
++access_m_idx) {
int access_idx = access_m_idx +
kTilesPerInstruction *
(inner_idx + kAccessesInner * inst_m_idx);
MatrixCoord offset(
access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
inner_idx * 4 * kElementsPerAccess);
if (access_idx == ldsm_vec_num) {
if (kTranspose) {
offset = MatrixCoord(offset.column(), offset.row());
}
origin_ += offset;
}
}
}
InstructionCount::kRow * kTilesPerInstruction == 4,
"can't use ldmatrix.x4");
int access_m_idx = ldsm_vec_num % kTilesPerInstruction;
int inner_idx = (ldsm_vec_num / kTilesPerInstruction) % kAccessesInner;
int inst_m_idx = ldsm_vec_num / (kTilesPerInstruction * kAccessesInner);
MatrixCoord offset(
access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
inner_idx * 4 * kElementsPerAccess);
if (kTranspose) {
offset = MatrixCoord(offset.column(), offset.row());
}
origin_ += offset;
} else {
// Note: This is not tested or used
origin_ = MatrixCoord(0, lane_id % 8);
static_assert(InstructionCount::kColumn * kAccessesInner == 4, "");
CUTLASS_PRAGMA_UNROLL
@ -256,17 +256,23 @@ class WarpIteratorFromSmem {
using LoadLayout = typename platform::
conditional<kTranspose, layout::ColumnMajor, layout::RowMajor>::type;
MatrixCoord offset;
if (kOperand == Operand::kA) {
offset = MatrixCoord(0, iterations_ * InstructionShape::kColumn);
} else {
offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0);
CUTLASS_PRAGMA_UNROLL
for (int access_m_idx = 0; access_m_idx <
(InstructionCount::kRow * kTilesPerInstruction * kAccessesInner) / 4;
++access_m_idx) {
MatrixCoord offset;
if (kOperand == Operand::kA) {
offset = MatrixCoord(
access_m_idx * 16, iterations_ * InstructionShape::kColumn);
} else {
offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0);
}
if (kTranspose) {
offset = MatrixCoord(offset.column(), offset.row());
}
cutlass::arch::ldsm<LoadLayout, 4>(
access_ptr[access_m_idx], ref_.data() + ref_.offset(offset));
}
if (kTranspose) {
offset = MatrixCoord(offset.column(), offset.row());
}
cutlass::arch::ldsm<LoadLayout, 4>(
access_ptr[0], ref_.data() + ref_.offset(offset));
}
};

File diff suppressed because it is too large


@ -66,6 +66,7 @@
#include "debug_utils.h"
#include "epilogue/epilogue_pipelined.h"
#include "epilogue/epilogue_rescale_output.h"
#include "gemm/custom_mma.h"
#include "gemm/find_default_mma.h"
#include "gemm/mma_from_smem.h"
#include "gemm_kernel_utils.h"
@ -77,7 +78,7 @@ using namespace gemm_kernel_utils;
namespace {
template <typename scalar_t, typename Arch>
constexpr int getWarpsPerSm() {
constexpr int getWarpsPerSmFw() {
return (
Arch::kMinComputeCapability >= 80 &&
!cutlass::platform::is_same<scalar_t, float>::value
@ -92,6 +93,24 @@ static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
}
} // namespace
// If a ToBatchHookType_ other than this default is supplied (which is
// never the case in the xformers library), the user defines the logic
// each block uses to find the data it works on, via an advance_to_batch
// function with the following signature.
// It should return false if there is no work to do for this block.
// In general this will not work with saving for backward, due to the fixed
// layout for logsumexp and incompatible RNGs for dropout, so it is likely
// only useful for custom inference.
struct DefaultToBatchHook {
template <typename Params>
CUTLASS_DEVICE static bool advance_to_batch(
Params&,
int64_t& /* q_start */,
int64_t& /* k_start */) {
return true;
}
};
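// Illustrative sketch (not part of this patch): a hypothetical custom hook
// that satisfies the contract described above. The `q_starts`, `k_starts`
// and `num_active_batches` fields are made-up names a user would add to
// their own Params type; only the signature and the "return false when
// there is no work" convention come from the comment above.
struct ExampleCustomToBatchHook {
  template <typename Params>
  CUTLASS_DEVICE static bool advance_to_batch(
      Params& p,
      int64_t& q_start,
      int64_t& k_start) {
    int batch_id = blockIdx.z; // assuming one batch per z-block
    if (batch_id >= p.num_active_batches) {
      return false; // nothing to do for this block
    }
    q_start = p.q_starts[batch_id];
    k_start = p.k_starts[batch_id];
    return true;
  }
};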
template <
// The datatype of Q/K/V
typename scalar_t_,
@ -99,13 +118,15 @@ template <
typename ArchTag,
// If Q/K/V are correctly aligned in memory and we can run a fast kernel
bool isAligned_,
int kQueriesPerBlock,
int kQueriesPerBlock_,
int kKeysPerBlock_,
bool kSingleValueIteration_, // = `value.shape[-1] <= kKeysPerBlock`
// upper bound on `max(value.shape[-1], query.shape[-1])`
int kMaxK_ = (int)cutlass::platform::numeric_limits<uint32_t>::max(),
// This is quite slower on V100 for some reason
// Set to false if you know at compile-time you will never need dropout
bool kSupportsDropout_ = true,
bool kSupportsBias_ = true>
bool kSupportsBias_ = true,
typename ToBatchHookType_ = DefaultToBatchHook>
struct AttentionKernel {
enum CustomMaskType {
NoCustomMask = 0,
@ -125,11 +146,14 @@ struct AttentionKernel {
static constexpr bool kSupportsDropout = kSupportsDropout_;
static constexpr bool kSupportsBias = kSupportsBias_;
static constexpr int kKeysPerBlock = kKeysPerBlock_;
static constexpr int kQueriesPerBlock = kQueriesPerBlock_;
static constexpr int kMaxK = kMaxK_;
static constexpr bool kIsAligned = isAligned_;
static constexpr bool kSingleValueIteration = kSingleValueIteration_;
static constexpr bool kSingleValueIteration = kMaxK <= kKeysPerBlock;
static constexpr int32_t kAlignLSE = 32; // block size of backward
static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 &&
cutlass::sizeof_bits<scalar_t>::value == 16;
static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value == 16;
static constexpr bool kPreloadV =
ArchTag::kMinComputeCapability >= 80 && kIsHalf;
static constexpr bool kKeepOutputInRF = kSingleValueIteration;
static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF &&
!cutlass::platform::is_same<output_accum_t, output_t>::value;
@ -143,66 +167,67 @@ struct AttentionKernel {
// Launch bounds
static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock;
static constexpr int kMinBlocksPerSm =
getWarpsPerSm<scalar_t, ArchTag>() / kNumWarpsPerBlock;
getWarpsPerSmFw<scalar_t, ArchTag>() / kNumWarpsPerBlock;
struct Params {
// Input tensors
scalar_t* query_ptr; // [num_queries, num_heads, head_dim]
scalar_t* key_ptr; // [num_keys, num_heads, head_dim]
scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value]
scalar_t* query_ptr = nullptr; // [num_queries, num_heads, head_dim]
scalar_t* key_ptr = nullptr; // [num_keys, num_heads, head_dim]
scalar_t* value_ptr = nullptr; // [num_keys, num_heads, head_dim_value]
scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys]
int32_t* seqstart_q_ptr = nullptr;
int32_t* seqstart_k_ptr = nullptr;
int32_t* causal_diagonal_ptr = nullptr;
int32_t* seqlen_k_ptr = nullptr;
uint32_t causal_diagonal_offset = 0;
// Output tensors
output_t* output_ptr; // [num_queries, num_heads, head_dim_value]
output_accum_t*
output_accum_ptr; // [num_queries, num_heads, head_dim_value]
lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null
output_t* output_ptr = nullptr; // [num_queries, num_heads, head_dim_value]
// [num_queries, num_heads, head_dim_value]
output_accum_t* output_accum_ptr = nullptr;
// [num_heads, num_queries] - can be null
lse_scalar_t* logsumexp_ptr = nullptr;
// Scale
accum_t scale;
accum_t scale = 0.0;
// Dimensions/strides
int32_t head_dim;
int32_t head_dim_value;
int32_t num_queries;
int32_t num_keys;
int32_t head_dim = 0;
int32_t head_dim_value = 0;
int32_t num_queries = 0;
int32_t num_keys = 0;
int32_t num_keys_absolute = 0;
uint8_t custom_mask_type = NoCustomMask;
int32_t q_strideM;
int32_t k_strideM;
int32_t v_strideM;
int32_t q_strideM = 0;
int32_t k_strideM = 0;
int32_t v_strideM = 0;
int32_t bias_strideM = 0;
int32_t o_strideM = 0;
// Everything below is only used in `advance_to_block`
// and shouldn't use registers
int32_t q_strideH;
int32_t k_strideH;
int32_t v_strideH;
int32_t bias_strideH = 0;
int32_t q_strideH = 0;
int32_t k_strideH = 0;
int32_t v_strideH = 0;
int64_t bias_strideH = 0;
int64_t q_strideB;
int64_t k_strideB;
int64_t v_strideB;
int32_t bias_strideB = 0;
int64_t q_strideB = 0;
int64_t k_strideB = 0;
int64_t v_strideB = 0;
int64_t bias_strideB = 0;
int32_t num_batches;
int32_t num_heads;
int32_t num_batches = 0;
int32_t num_heads = 0;
// dropout
bool use_dropout;
unsigned long long dropout_batch_head_rng_offset;
float dropout_prob;
bool use_dropout = false;
unsigned long long dropout_batch_head_rng_offset = 0;
float dropout_prob = 0.0f;
#ifdef HAS_PYTORCH
at::PhiloxCudaState rng_engine_inputs;
at::PhiloxCudaState rng_engine_inputs = at::PhiloxCudaState(0, 0);
#endif
// Moves pointers to what we should process
@ -220,9 +245,17 @@ struct AttentionKernel {
head_id * num_queries * num_keys;
}
int64_t q_start, k_start;
int64_t q_start = 0, k_start = 0;
// Advance to current batch - in case of different sequence lengths
if (seqstart_q_ptr != nullptr) {
constexpr bool kToBatchHook =
!cutlass::platform::is_same<ToBatchHookType_, DefaultToBatchHook>::
value;
if (kToBatchHook) {
// Call out to a custom implementation.
if (!ToBatchHookType_::advance_to_batch(*this, q_start, k_start)) {
return false;
}
} else if (seqstart_q_ptr != nullptr) {
assert(seqstart_k_ptr != nullptr);
seqstart_q_ptr += batch_id;
@ -285,12 +318,12 @@ struct AttentionKernel {
}
// Custom masking
if (causal_diagonal_ptr) {
causal_diagonal_offset = causal_diagonal_ptr[batch_id];
}
if (custom_mask_type == CausalFromBottomRight) {
causal_diagonal_offset += num_keys - num_queries;
causal_diagonal_offset = num_keys - num_queries;
}
// We use num_keys_absolute to index into the rng_state
// We need this index to match between forward and backwards
num_keys_absolute = num_keys;
if (custom_mask_type == CausalFromTopLeft ||
custom_mask_type == CausalFromBottomRight) {
// the bottom row of the current block is query_start + kQueriesPerBlock
@ -323,6 +356,7 @@ struct AttentionKernel {
// Make sure the compiler knows these variables are the same on all
// the threads of the warp.
// Only worth doing if they could have been modified above.
query_ptr = warp_uniform(query_ptr);
key_ptr = warp_uniform(key_ptr);
value_ptr = warp_uniform(value_ptr);
@ -335,8 +369,6 @@ struct AttentionKernel {
num_queries = warp_uniform(num_queries);
num_keys = warp_uniform(num_keys);
num_heads = warp_uniform(num_heads);
head_dim = warp_uniform(head_dim);
head_dim_value = warp_uniform(head_dim_value);
o_strideM = warp_uniform(o_strideM);
custom_mask_type = warp_uniform(custom_mask_type);
return true;
@ -395,14 +427,19 @@ struct AttentionKernel {
ThreadblockShape, // ThreadblockShape
WarpShape, // WarpShape
typename GemmType::InstructionShape, // InstructionShape
DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but that
// uses too much smem
ArchTag::kMinComputeCapability >= 80 && kIsHalf
? 4
: DefaultConfig::kStages,
typename GemmType::Operator // Operator
>::DefaultMma;
using MmaCore = typename DefaultMma::MmaCore;
using IteratorA = typename DefaultMma::IteratorA;
using IteratorB = typename DefaultMma::IteratorB;
using Mma = typename DefaultMma::ThreadblockMma;
using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma;
using Mma = typename cutlass::platform::conditional<
kSingleValueIteration,
typename MakeCustomMma<DefaultThreadblockMma, kMaxK>::Mma,
DefaultThreadblockMma>::type;
using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
typename Mma::Operator::IteratorC,
accum_t,
@ -475,14 +512,23 @@ struct AttentionKernel {
typename GemmType::InstructionShape,
typename DefaultConfig::EpilogueOutputOp,
void, // ThreadblockSwizzle - not used
DefaultConfig::kStages,
ArchTag::kMinComputeCapability >= 80 && kIsHalf
? 4
: DefaultConfig::kStages,
false, // SplitKSerial
typename GemmType::Operator>;
using WarpIteratorA = typename cutlass::gemm::threadblock::
DefaultWarpIteratorAFromSharedMemory<
typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape
typename DefaultGemm::Mma::Policy::Operator::InstructionShape,
typename DefaultGemm::Mma::Policy::Operator::IteratorA,
typename DefaultGemm::Mma::Policy>::WarpIterator;
using DefaultMmaFromSmem =
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
typename DefaultGemm::Mma,
typename MM0::AccumulatorSharedStorage,
MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK
WarpIteratorA,
false>; // kScaleOperandA
using Mma = typename DefaultMmaFromSmem::Mma;
using IteratorB = typename Mma::IteratorB;
@ -500,10 +546,6 @@ struct AttentionKernel {
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
output_accum_t>;
struct SharedStorageMM1 {
typename Mma::SharedStorage mm;
};
};
static constexpr int64_t kAlignmentQ = MM0::kAlignmentA;
@ -515,6 +557,9 @@ struct AttentionKernel {
cutlass::Array<accum_t, kQueriesPerBlock> m_prime;
cutlass::Array<accum_t, kQueriesPerBlock> s_prime;
cutlass::Array<accum_t, kQueriesPerBlock> mi;
cutlass::Array<accum_t, kQueriesPerBlock> out_rescale;
cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>
addition_storage;
};
struct SharedStorageEpilogueAtEnd : ScalingCoefs {
@ -524,7 +569,7 @@ struct AttentionKernel {
typename MM0::BiasLoader::SmemTile bias;
typename MM0::AccumulatorSharedStorage si;
};
typename MM1::SharedStorageMM1 mm1;
typename MM1::Mma::SharedStorage mm1;
};
union {
@ -546,7 +591,7 @@ struct AttentionKernel {
typename MM0::BiasLoader::SmemTile bias;
typename MM0::AccumulatorSharedStorage si;
};
typename MM1::SharedStorageMM1 mm1;
typename MM1::Mma::SharedStorage mm1;
typename MM1::DefaultEpilogue::SharedStorage epilogue;
};
@ -573,30 +618,33 @@ struct AttentionKernel {
if (kSupportsBias) {
CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
XFORMERS_CHECK(
p.bias_strideB % kAlignmentQ == 0,
"attn_bias is not correctly aligned");
p.num_batches <= 1 || p.bias_strideB % kAlignmentQ == 0,
"attn_bias is not correctly aligned (strideB)");
XFORMERS_CHECK(
p.bias_strideH % kAlignmentQ == 0,
"attn_bias is not correctly aligned");
p.num_heads <= 1 || p.bias_strideH % kAlignmentQ == 0,
"attn_bias is not correctly aligned (strideH)");
XFORMERS_CHECK(
p.bias_strideM % kAlignmentQ == 0,
"attn_bias is not correctly aligned");
}
XFORMERS_CHECK(
p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned");
p.q_strideM % kAlignmentQ == 0,
"query is not correctly aligned (strideM)");
XFORMERS_CHECK(
p.k_strideM % kAlignmentK == 0, "key is not correctly aligned");
p.k_strideM % kAlignmentK == 0,
"key is not correctly aligned (strideM)");
XFORMERS_CHECK(
p.v_strideM % kAlignmentV == 0, "value is not correctly aligned");
p.v_strideM % kAlignmentV == 0,
"value is not correctly aligned (strideM)");
XFORMERS_CHECK(
p.q_strideH % kAlignmentQ == 0, "query is not correctly aligned");
p.num_heads <= 1 || p.q_strideH % kAlignmentQ == 0,
"query is not correctly aligned (strideH)");
XFORMERS_CHECK(
p.k_strideH % kAlignmentK == 0, "key is not correctly aligned");
p.num_heads <= 1 || p.k_strideH % kAlignmentK == 0,
"key is not correctly aligned (strideH)");
XFORMERS_CHECK(
p.v_strideH % kAlignmentV == 0, "value is not correctly aligned");
XFORMERS_CHECK(
p.causal_diagonal_ptr == nullptr || p.custom_mask_type != NoCustomMask,
"`causal_diagonal_ptr` is only useful when `custom_mask_type` is causal");
p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0,
"value is not correctly aligned (strideH)");
XFORMERS_CHECK(
p.custom_mask_type < NumCustomMaskTypes,
"invalid value for `custom_mask_type`");
@ -613,11 +661,13 @@ struct AttentionKernel {
auto& m_prime = shared_storage.m_prime;
auto& s_prime = shared_storage.s_prime;
auto& mi = shared_storage.mi;
auto& out_rescale = shared_storage.out_rescale;
const uint32_t query_start = blockIdx.x * kQueriesPerBlock;
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
if (thread_id() < kQueriesPerBlock) {
s_prime[thread_id()] = accum_t(0);
out_rescale[thread_id()] = accum_t(1.0);
m_prime[thread_id()] =
-cutlass::platform::numeric_limits<accum_t>::infinity();
mi[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
@ -689,7 +739,7 @@ struct AttentionKernel {
thread_id(),
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
MM1::Mma::prologue(
shared_storage.after_mm0.mm1.mm,
shared_storage.after_mm0.mm1,
iterator_V,
thread_id(),
problem_size_1_k);
@ -733,7 +783,7 @@ struct AttentionKernel {
thread_id(),
tb_offset_B);
auto my_warp_id = warp_id();
auto my_warp_id = warp_uniform(warp_id());
auto my_lane_id = lane_id();
// Construct thread-scoped matrix multiply
@ -753,6 +803,8 @@ struct AttentionKernel {
if (kPreloadV) {
prologueV(0);
} else {
MM1::Mma::drain_cp_asyncs();
}
typename MM0::Mma::Operator::IteratorC::TensorCoord
@ -787,7 +839,7 @@ struct AttentionKernel {
// Pij += Bij, Pij is in register fragment and Bij is in shared memory
auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
lane_id(), warp_id(), iteratorC_tile_offset);
my_lane_id, my_warp_id, iteratorC_tile_offset);
MM0::AccumLambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {},
@ -811,7 +863,7 @@ struct AttentionKernel {
(query_start + p.causal_diagonal_offset)) {
auto query_start = blockIdx.x * kQueriesPerBlock;
auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
lane_id(), warp_id(), iteratorC_tile_offset);
my_lane_id, my_warp_id, iteratorC_tile_offset);
int32_t last_col;
MM0::AccumLambdaIterator::iterateRows(
lane_offset,
@ -830,30 +882,23 @@ struct AttentionKernel {
},
[&](int accum_m) {});
}
DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
p.num_keys - iter_key_start >= kKeysPerBlock,
kFullColumns,
([&] {
// Update `mi` from accum stored in registers
// Also does accum[i] <- exp(accum[i] - mi)
iterative_softmax<
typename MM0::Mma::Operator::IteratorC,
kFullColumns,
kIsFirst>(
accum_o,
accum,
mi,
m_prime,
s_prime,
lane_id(),
thread_id(),
warp_id(),
p.num_keys - iter_key_start,
iteratorC_tile_offset,
kSupportsBias ? 1.0f : p.scale);
}));
}));
// Update `mi` from accum stored in registers
// Also does accum[i] <- exp(accum[i] - mi)
iterative_softmax<typename MM0::Mma::Operator::IteratorC>(
accum_o,
accum,
mi,
m_prime,
s_prime,
out_rescale,
shared_storage.addition_storage,
my_lane_id,
thread_id(),
my_warp_id,
p.num_keys - iter_key_start,
iter_key_start == 0,
iteratorC_tile_offset,
kSupportsBias ? 1.0f : p.scale);
// Output results to shared-memory
int warp_idx_mn_0 = my_warp_id %
@ -904,7 +949,7 @@ struct AttentionKernel {
curandStatePhilox4_32_10_t curand_state = curand_state_init;
skipahead(
static_cast<unsigned long long>(
(query_start + thread_i) * p.num_keys +
(query_start + thread_i) * p.num_keys_absolute +
(iter_key_start + thread_start_j)),
&curand_state);
const float dropout_scale = 1.0 / (1.0 - p.dropout_prob);
@ -958,12 +1003,14 @@ struct AttentionKernel {
thread_id(),
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
typename MM1::Mma mma_pv(
shared_storage.after_mm0.mm1.mm,
shared_storage.after_mm0.si,
// operand A: Pij_dropped in shared memory
shared_storage.after_mm0.si.accum_ref(),
// operand B: shared memory staging area for Vj, which is loaded
// from global memory
shared_storage.after_mm0.mm1.operand_B_ref(),
(int)thread_id(),
(int)warp_id(),
(int)lane_id(),
(int)problem_size_1_k);
(int)my_warp_id,
(int)my_lane_id);
mma_pv.set_prologue_done(kPreloadV);
if (!kKeepOutputInRF) {
accum_o.clear();
@ -976,6 +1023,7 @@ struct AttentionKernel {
}
if (!kKeepOutputInRF) {
MM1::Mma::drain_cp_asyncs();
DISPATCH_BOOL(
iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
@ -1027,12 +1075,12 @@ struct AttentionKernel {
decltype(createOutputIter),
decltype(createOutputAccumIter)>::
apply(createOutputIter, createOutputAccumIter, col);
EpilogueOutputOp rescale(s_prime, m_prime);
EpilogueOutputOp rescale(s_prime, out_rescale);
Epilogue epilogue(
shared_storage.epilogue_shared_storage(),
thread_id(),
warp_id(),
lane_id());
my_warp_id,
my_lane_id);
epilogue(rescale, dest_iter, accum_o, source_iter);
}));
}));
@ -1076,12 +1124,13 @@ struct AttentionKernel {
typename MM1::OutputTileIteratorAccum // source tile
>;
auto dest_iter = createOutputIter(0);
EpilogueOutputOp rescale(s_prime, m_prime);
EpilogueOutputOp rescale(s_prime, out_rescale);
Epilogue epilogue(
shared_storage.epilogue_shared_storage(),
thread_id(),
warp_id(),
lane_id());
MM1::Mma::drain_cp_asyncs();
epilogue(rescale, dest_iter, accum_o);
}
@ -1091,8 +1140,9 @@ struct AttentionKernel {
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
if (thread_id() < p.num_queries) {
p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()]) +
p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) +
cutlass::fast_log(accum_t(s_prime[thread_id()]));
} else if (thread_id() < lse_dim) {
p.logsumexp_ptr[thread_id()] =
@ -1101,20 +1151,21 @@ struct AttentionKernel {
}
}
template <
typename WarpIteratorC,
bool kFullColumns,
bool kIsFirst>
template <typename WarpIteratorC>
CUTLASS_DEVICE static void iterative_softmax(
typename WarpIteratorC::Fragment& frag_o, // output so far
typename WarpIteratorC::Fragment& frag,
cutlass::Array<accum_t, kQueriesPerBlock>& mi,
cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
cutlass::Array<accum_t, kQueriesPerBlock>& out_rescale,
cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>&
addition_storage,
int8_t lane_id,
int8_t thread_id,
int8_t warp_id,
int16_t max_col,
int max_col,
bool is_first,
typename WarpIteratorC::TensorCoord const& tile_offset,
float scaling) {
/* Iterates on the accumulator and corresponding position on result matrix
@ -1135,12 +1186,11 @@ struct AttentionKernel {
kWarpSize>::Iterator;
// Convert to `accum_t` (rather than double)
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
if (!kIsFirst) {
if (thread_id < kQueriesPerBlock) {
m_prime[thread_id] = mi[thread_id];
}
__syncthreads();
}
static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, "");
static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock;
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
auto lane_offset =
LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
@ -1154,46 +1204,64 @@ struct AttentionKernel {
max = -cutlass::platform::numeric_limits<accum_t>::infinity();
},
[&](int accum_m, int accum_n, int idx) {
if (kFullColumns || accum_n < max_col) {
if (accum_n < max_col) {
max = cutlass::fast_max(max, frag[idx]);
}
},
[&](int accum_m) {
// Having 4x atomicMax seems faster than reduce within warp
// first...
atomicMaxFloat(&mi[accum_m], max * scaling);
atomicMaxFloat(&mi[accum_m], max);
});
}
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
// Make sure we all share the update values for `mi`
__syncthreads();
if (thread_id < kQueriesPerBlock) {
auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
m_prime[thread_id] = m_prime_exp;
s_prime[thread_id] *= m_prime_exp;
// Doing this `exp` is quite expensive. Let's
// split it across the warps
bool restore_mi_to_minus_inf = false;
if (lane_id < kLinesPerWarp) {
int id = warp_id * kLinesPerWarp + lane_id;
auto m_prime_id = m_prime[id];
auto mi_id = mi[id];
bool changed = m_prime_id < mi_id; // `false` if both are -inf
if (changed) {
auto m_prime_exp = exp2f(m_prime_id - mi_id);
out_rescale[id] = m_prime_exp;
s_prime[id] *= m_prime_exp;
} else {
// Only when bias is enabled is it possible that all the first values
// of attention are masked to `-inf`. In that case we want to avoid
// `nan = exp2f(-inf - (-inf))`, so we temporarily set `mi` to 0
if (kSupportsBias &&
mi_id == -cutlass::platform::numeric_limits<accum_t>::infinity()) {
restore_mi_to_minus_inf = true;
mi[id] = 0.0f;
}
out_rescale[id] = 1.0f;
}
}
__syncthreads(); // Update output fragments
if (kKeepOutputInRF && !kIsFirst) {
accum_t mp;
if (kKeepOutputInRF && !is_first) {
accum_t line_rescale;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { mp = m_prime[accum_m]; },
[&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
[&](int accum_m) { line_rescale = out_rescale[accum_m]; },
[&](int accum_m, int accum_n, int idx) {
frag_o[idx] = frag_o[idx] * line_rescale;
},
[&](int accum_m) {});
__syncthreads();
}
// Update accum_m, accum_n, ...
{
accum_t mi_row, total_row;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
[&](int accum_m) { mi_row = mi[accum_m]; },
[&](int accum_m, int accum_n, int idx) {
frag[idx] = (kFullColumns || accum_n < max_col)
? exp2f(frag[idx] - mi_row)
: accum_t(0.0);
frag[idx] =
(accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
},
[&](int accum_m) {});
LambdaIterator::iterateRows(
@ -1205,10 +1273,30 @@ struct AttentionKernel {
lane_id, total_row, [](accum_t a, accum_t b) {
return a + b;
})) {
atomicAdd(&s_prime[accum_m], total_row);
// NOTE: we could atomically add `total_row` to `s_prime`, but
// it's faster (and deterministic) to avoid atomics here
addition_storage
[accum_m + kQueriesPerBlock * tile_offset.column()] =
total_row;
}
});
}
__syncthreads();
if (lane_id < kLinesPerWarp) {
int id = warp_id * kLinesPerWarp + lane_id;
accum_t total_row = s_prime[id];
if (restore_mi_to_minus_inf) {
// Restore `mi`, see above when we set `restore_mi_to_minus_inf=true`
mi[id] = -cutlass::platform::numeric_limits<accum_t>::infinity();
} else {
m_prime[id] = mi[id];
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
total_row += addition_storage[id + kQueriesPerBlock * i];
}
s_prime[id] = total_row;
}
}
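// Illustrative sketch (not part of this patch): the per-row recurrence that
// iterative_softmax implements across fragments and warps, written here for
// one row of plain scalars and in natural-log units (the kernel keeps `mi`
// in base-2 units and uses exp2f). The P@V product itself is accumulated by
// the second GEMM; this only shows how the running max, the denominator and
// the already-accumulated output are rescaled when a new block of keys
// arrives.
CUTLASS_DEVICE static void online_softmax_row_step(
    float const* scores, int num_new_cols, // scores for the new key block
    float& m, // running row max
    float& s, // running softmax denominator
    float* acc, int acc_cols) { // output accumulated so far
  float m_old = m;
  for (int j = 0; j < num_new_cols; ++j) {
    m = fmaxf(m, scores[j]);
  }
  // rescale == 1 when the max did not change (including the all-masked case
  // where both are -inf, which would otherwise produce exp(NaN))
  float rescale = (m_old == m) ? 1.0f : expf(m_old - m);
  s *= rescale;
  for (int i = 0; i < acc_cols; ++i) {
    acc[i] *= rescale;
  }
  for (int j = 0; j < num_new_cols; ++j) {
    s += expf(scores[j] - m);
  }
}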
static CUTLASS_DEVICE int8_t lane_id() {


@ -0,0 +1,112 @@
from typing import List
import torch
import subprocess
import sys
import tempfile
import os
import numpy as np

TORCH_DTYPE_NAME = {
    torch.float32: "f32",
    torch.float16: "f16",
    torch.bfloat16: "b16"
}
NAME_TORCH_DTYPE = {v: k for k, v in TORCH_DTYPE_NAME.items()}


def _tensor_from_storage(tensor: torch.Tensor, dtype) -> torch.Tensor:
    # PyTorch >= 2.0
    if hasattr(tensor, 'untyped_storage'):
        return torch.tensor([], dtype=dtype).set_(tensor.untyped_storage())
    return torch.tensor([], dtype=dtype).set_(tensor.storage().untyped())


class PipedSubprocess:
    def __init__(self, binary: str) -> None:
        self.binary = binary
        self.tempdir_ctx = tempfile.TemporaryDirectory()

    def __enter__(self) -> "PipedSubprocess":
        self.subp = subprocess.Popen(self.binary, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=sys.stderr, text=True, bufsize=0)
        self.tempdir = self.tempdir_ctx.__enter__()
        self.file_counter = 0
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.tempdir_ctx.__exit__(exc_type, exc_val, exc_tb)

    def temp_filename(self, suffix: str) -> str:
        self.file_counter += 1
        return os.path.join(self.tempdir, f"{self.file_counter}{suffix}")

    def write(self, *args) -> None:
        for a in args:
            self.subp.stdin.write(str(a) + " ")

    def writeTensor(self, tensor: torch.Tensor, name: str, stride_names: List[str]) -> None:
        print(f"Py ->C++: {TORCH_DTYPE_NAME[tensor.dtype]}:{name}")
        tensor_u8 = _tensor_from_storage(tensor, torch.uint8)
        self.write("tensor_begin", f"{TORCH_DTYPE_NAME[tensor.dtype]}:{name}", tensor_u8.shape[0])
        filename = self.temp_filename(f"{name}.tensor")
        assert tensor.storage_offset() == 0
        with open(filename, "wb+") as fd:
            fd.write(bytes(tensor_u8.numpy()))
        self.write("file", filename)
        self.write("tensor_end")
        for stride_name, stride_value in zip(stride_names, tensor.stride()):
            self.write(stride_name, stride_value)

    def readTensor(self, name, stride_name, shape) -> torch.Tensor:
        tmpfile = self.temp_filename(f"{name}.tensor")
        self.write("tmpfile", tmpfile)
        self.readExpect("tensor_begin")
        dtype_str, name = self.read().split(":")
        print(f"C++->Py : {dtype_str}:{name}")
        u8len = int(self.read())
        dtype = NAME_TORCH_DTYPE[dtype_str]
        self.readExpect("file")
        self.readExpect(tmpfile)
        with open(tmpfile, "rb") as fd:
            data = fd.read(u8len)
        # `np.array` is not strictly needed, but avoids a torch warning
        tensor_u8 = torch.frombuffer(np.array(data), dtype=torch.uint8, count=u8len)
        self.readExpect("tensor_end")
        tensor = _tensor_from_storage(tensor_u8, dtype)
        strides = []
        for sn in stride_name:
            self.readExpect(sn)
            strides.append(int(self.read()))
        if len(strides) != len(shape):
            strides.append(1)
        assert len(strides) == len(shape), name
        return torch.as_strided(tensor, shape, strides)

    def readNamed(self, name: str):
        self.readExpect(name)
        return self.read()

    def readExpect(self, what: str) -> None:
        r = self.read()
        if r != what:
            raise ValueError(f"Read {r} but expected {what}")

    def read(self):
        read_all = []
        # Skip initial whitespace
        while True:
            r = self.subp.stdout.read(1)
            if r not in [' ', "\n"]:
                read_all.append(r)
                break
        # Read data
        while True:
            r = self.subp.stdout.read(1)
            if r in [' ', "\n"]:
                break
            read_all.append(r)
        return ''.join(read_all)


@ -29,6 +29,8 @@
*
**************************************************************************************************/
#pragma once
#include <cutlass/cutlass.h>
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"


@ -434,14 +434,6 @@ class gen_device:
" if (result != cudaSuccess) {\n" + \
" return Status::kErrorInternal;\n" + \
" }\n" + \
"\n" + \
" result = cudaFuncSetAttribute(\n" + \
" Kernel<B2bGemmKernel>,\n" + \
" cudaFuncAttributePreferredSharedMemoryCarveout, 100);\n" + \
"\n" + \
" if (result != cudaSuccess) {\n" + \
" return Status::kErrorInternal;\n" + \
" }\n" + \
" }\n" + \
" cutlass::Kernel<B2bGemmKernel><<<grid, block, smem_size, stream>>>(params_);\n" + \
" result = cudaGetLastError();\n" + \


@ -331,7 +331,7 @@ class gen_Kernel:
operator_code += " " + helper.var_idx("FusedAddBiasEpilogue", i ) + helper.var_idx(" epilogue_", i ) + ";\n"
operator_code += " " + "int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);\n"
operator_code += " " + "int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);\n"
operator_code += " " + "int lane_idx = threadIdx.x % 32;\n"
for i in range (self.b2bnum - 1):


@ -159,7 +159,7 @@ class DualGemm {
using Mma0 = typename cutlass::gemm::threadblock::DefaultMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB0, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
ThreadblockShape, WarpShape,
ThreadblockShape, WarpShape,
InstructionShape, Stages, Operator>::ThreadblockMma;
using Mma1 = typename cutlass::gemm::threadblock::DefaultMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB1, kAlignmentB,
@ -348,7 +348,7 @@ public:
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
args.split_k_slices);


@ -167,10 +167,10 @@ bool run_nonfused_gemm_f16_sm80() {
std::cout << "Running Non-fused GEMMs FP16 TN GEMMs...\n";
bool pass = nonFusedGemm.run(
problem_size,
alpha0,
beta0,
alpha1,
problem_size,
alpha0,
beta0,
alpha1,
beta1,
true /* is_profiling */
);
@ -248,10 +248,10 @@ bool run_fused_gemm_f16_sm80_shmem() {
std::cout << "Running Fused FP16 TN GEMMs + Epilogue2...\n";
bool passed = fusedGemm.run(
problem_size,
alpha0,
beta0,
alpha1,
problem_size,
alpha0,
beta0,
alpha1,
beta1
);
@ -301,11 +301,11 @@ bool run_batched_fused_gemm_f16_sm80_shmem() {
std::cout << "Running Batched Fused FP16 TN GEMMs + Epilogue2...\n";
bool passed = fusedGemm.run(
batch_problem_size,
alpha0,
beta0,
alpha1,
beta1,
batch_problem_size,
alpha0,
beta0,
alpha1,
beta1,
kBatchCount,
false, /* broadcast_b1 */
false /* is_profiling */
@ -358,11 +358,11 @@ bool run_broadcast_fused_gemm_f16_sm80_shmem() {
std::cout << "Running Broadcast Fused FP16 TN GEMMs + Epilogue2...\n";
bool passed = fusedGemm.run(
problem_size,
alpha0,
beta0,
alpha1,
beta1,
problem_size,
alpha0,
beta0,
alpha1,
beta1,
1, /* batch_count */
true, /* broadcast_b1 */
true /* is_profiling */
@ -415,11 +415,11 @@ bool run_batched_broadcast_fused_gemm_f16_sm80_shmem() {
std::cout << "Running Batch Broadcast Fused FP16 TN GEMMs + Epilogue2...\n";
bool passed = fusedGemm.run(
batch_problem_size,
alpha0,
beta0,
alpha1,
beta1,
batch_problem_size,
alpha0,
beta0,
alpha1,
beta1,
kBatchCount,
true, /* broadcast_b1 */
false /* is_profiling */
@ -444,11 +444,11 @@ int main() {
};
std::string test_name = (
"dual-gemm f16 bias=" +
std::to_string(kUseBias) +
" split_k_serial=" +
"dual-gemm f16 bias=" +
std::to_string(kUseBias) +
" split_k_serial=" +
std::to_string(kSplitKSerial) +
" batch_count=" +
" batch_count=" +
std::to_string(kBatchCount)
);


@ -45,6 +45,7 @@
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_relu.h"
#include "cutlass/platform/platform.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"
@ -356,13 +357,13 @@ struct NonFusedDualGemmRun
for(int i = 0; i < runs; i++) {
status = gemm_op_0();
CUTLASS_CHECK(status);
}
cudaEventRecord(stop1);
for(int i = 0; i < runs; i++) {
status = gemm_op_1();
CUTLASS_CHECK(status);
}
@ -564,22 +565,22 @@ struct DualFusedGemmRun
cutlass::HostTensor<
typename DualGemm::ElementA,
typename DualGemm::LayoutA> tensor_A0(
std::is_same<typename DualGemm::LayoutA, cutlass::layout::RowMajor>::value ?
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.k()) :
cutlass::platform::is_same<typename DualGemm::LayoutA, cutlass::layout::RowMajor>::value ?
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.k()) :
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.k()));
cutlass::HostTensor<
typename DualGemm::ElementB,
typename DualGemm::LayoutB0> tensor_B0(
std::is_same<typename DualGemm::LayoutB0, cutlass::layout::RowMajor>::value ?
cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) :
cutlass::platform::is_same<typename DualGemm::LayoutB0, cutlass::layout::RowMajor>::value ?
cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) :
cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n()));
cutlass::HostTensor<
typename DualGemm::ElementC,
typename DualGemm::LayoutC> tensor_C0(
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
cutlass::HostTensor<
@ -589,22 +590,22 @@ struct DualFusedGemmRun
cutlass::HostTensor<
typename DualGemm::ElementC,
typename DualGemm::LayoutC> tensor_D0(
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
cutlass::HostTensor<
typename DualGemm::ElementC,
typename DualGemm::LayoutC> reference_D0(
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
cutlass::HostTensor<
typename DualGemm::ElementB,
typename DualGemm::LayoutB1> tensor_B1(
std::is_same<typename DualGemm::LayoutB1, cutlass::layout::RowMajor>::value ?
cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) :
cutlass::platform::is_same<typename DualGemm::LayoutB1, cutlass::layout::RowMajor>::value ?
cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) :
cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n()));
if (broadcast_b1) {
tensor_B1.resize({problem_size.k(), batch_count});
@ -613,8 +614,8 @@ struct DualFusedGemmRun
cutlass::HostTensor<
typename DualGemm::ElementC,
typename DualGemm::LayoutC> tensor_C1(
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
cutlass::HostTensor<
@ -624,29 +625,29 @@ struct DualFusedGemmRun
cutlass::HostTensor<
typename DualGemm::ElementC,
typename DualGemm::LayoutC> tensor_D1(
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
cutlass::HostTensor<
typename DualGemm::ElementC,
typename DualGemm::LayoutC> tensor_D2(
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
cutlass::HostTensor<
typename DualGemm::ElementC,
typename DualGemm::LayoutC> reference_D1(
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
cutlass::HostTensor<
typename DualGemm::ElementC,
typename DualGemm::LayoutC> reference_D2(
std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
@ -712,16 +713,16 @@ struct DualFusedGemmRun
ref_B1 = {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)};
}
typename DualGemm::Arguments arguments{
(batch_count > 1 ?
cutlass::gemm::DualGemmMode::kBatched :
(batch_count > 1 ?
cutlass::gemm::DualGemmMode::kBatched :
cutlass::gemm::DualGemmMode::kGemm),
problem_size,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
ref_B0,
DualGemm::kStoreD0 ? tensor_D0.device_ref() : nullptr_ref,
(broadcast_b1 ?
typename DualGemm::TensorRefB1(tensor_B1.device_data(), 0) :
(broadcast_b1 ?
typename DualGemm::TensorRefB1(tensor_B1.device_data(), 0) :
tensor_B1.device_ref()),
ref_B1,
DualGemm::kStoreD1 ? tensor_D1.device_ref() : nullptr_ref,
@ -793,15 +794,15 @@ struct DualFusedGemmRun
using GemmUniversal0 = cutlass::gemm::device::GemmUniversal<
typename DualGemm::ElementA, typename DualGemm::LayoutA,
typename DualGemm::ElementB, typename DualGemm::LayoutB0,
typename DualGemm::ElementC, typename DualGemm::LayoutC,
typename DualGemm::ElementC, typename DualGemm::LayoutC,
ElementAccumulator
>;
GemmUniversal0 reference_gemm0;
typename GemmUniversal0::Arguments args0 {
(batch_count > 1 ?
cutlass::gemm::GemmUniversalMode::kBatched :
(batch_count > 1 ?
cutlass::gemm::GemmUniversalMode::kBatched :
cutlass::gemm::GemmUniversalMode::kGemm),
problem_size,
batch_count,
@ -828,15 +829,15 @@ struct DualFusedGemmRun
using GemmUniversal1 = cutlass::gemm::device::GemmUniversal<
typename DualGemm::ElementA, typename DualGemm::LayoutA,
typename DualGemm::ElementB, typename DualGemm::LayoutB1,
typename DualGemm::ElementC, typename DualGemm::LayoutC,
typename DualGemm::ElementC, typename DualGemm::LayoutC,
ElementAccumulator
>;
GemmUniversal1 reference_gemm1;
typename GemmUniversal1::Arguments args1 {
(batch_count > 1 ?
cutlass::gemm::GemmUniversalMode::kBatched :
(batch_count > 1 ?
cutlass::gemm::GemmUniversalMode::kBatched :
cutlass::gemm::GemmUniversalMode::kGemm),
problem_size,
batch_count,
@ -861,7 +862,7 @@ struct DualFusedGemmRun
CUTLASS_CHECK(status);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D0.device_view());
cutlass::reference::device::TensorReLu(reference_D0.device_view());
cutlass::reference::device::TensorReLu(reference_D1.device_view());
}


@ -300,7 +300,7 @@ struct DualGemm {
int offset_k = 0;
int problem_size_k = params.problem_size.k();
ElementA *ptr_A0 = static_cast<ElementA *>(params.ref_A0.data());
ElementA *ptr_A0 = static_cast<ElementA *>(params.ref_A0.data());
ElementB *ptr_B0 = static_cast<ElementB *>(params.ref_B0.data());
ElementB *ptr_B1 = static_cast<ElementB *>(params.ref_B1.data());
@ -309,7 +309,7 @@ struct DualGemm {
//
if (params.mode == DualGemmMode::kGemm) {
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
}
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
@ -364,7 +364,7 @@ struct DualGemm {
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
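// Note: the first argument of __shfl_sync is a 32-bit lane participation mask, not a lane count;
// 0xffffffff includes all 32 lanes of the warp, whereas 0x1f would only include lanes 0 through 4.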
int lane_idx = threadIdx.x % 32;
//
@ -413,11 +413,11 @@ struct DualGemm {
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
ElementC *ptr_C0 = static_cast<ElementC *>(params.ref_C0.data());
ElementC *ptr_C1 = static_cast<ElementC *>(params.ref_C1.data());
ElementC *ptr_D0 = static_cast<ElementC *>(params.ref_D0.data());
ElementC *ptr_D1 = static_cast<ElementC *>(params.ref_D1.data());
ElementC *ptr_D2 = static_cast<ElementC *>(params.ref_D2.data());
ElementC *ptr_C0 = static_cast<ElementC *>(params.ref_C0.data());
ElementC *ptr_C1 = static_cast<ElementC *>(params.ref_C1.data());
ElementC *ptr_D0 = static_cast<ElementC *>(params.ref_D0.data());
ElementC *ptr_D1 = static_cast<ElementC *>(params.ref_D1.data());
ElementC *ptr_D2 = static_cast<ElementC *>(params.ref_D2.data());
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
@ -425,7 +425,7 @@ struct DualGemm {
if (params.mode == DualGemmMode::kGemm) {
// If performing a reduction via split-K, fetch the initial synchronization
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();

View File

@ -759,13 +759,10 @@ public:
accum1 = plus_accum(accum1, tmp_accum1);
}
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
// commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
// commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
};

View File

@ -233,6 +233,17 @@ struct Options {
return false;
}
// Filter size passed through command line does not match filter size template parameter
if (filter_size.h() != FilterShape::kRow || filter_size.w() != FilterShape::kColumn) {
std::cerr << "Filter size passed in (" << filter_size.h() << "x" << filter_size.w() << ") "
<< "must match the FilterShape template parameter of the convolution "
<< "(" << FilterShape::kRow << "x" << FilterShape::kColumn << "). "
<< "To use the filter shape passed in, change the FilterShape template "
<< "parameter and recompile this example."
<< std::endl;
return false;
}
return true;
}
@ -319,9 +330,9 @@ struct Options {
"table\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/45_depthwise_simt_conv2dfprop/45_depthwise_simt_conv2dfprop --n=32 "
<< "$ ./examples/46_depthwise_simt_conv2dfprop/46_depthwise_simt_conv2dfprop --n=32 "
"--h=224 --w=224 --c=128 --k=128 --g=128 --r=3 --s=3\n\n"
<< "$ ./examples/45_depthwise_simt_conv2dfprop/45_depthwise_simt_conv2dfprop --n=1 "
<< "$ ./examples/46_depthwise_simt_conv2dfprop/46_depthwise_simt_conv2dfprop --n=1 "
"--h=224 --w=224 --c=32 --k=32 --g=32 --r=3 --s=3 --splitk=10 --ref-check\n\n";
return out;
@ -515,14 +526,13 @@ Result profile_convolution(Options const &options) {
ElementOutput,
LayoutOutput,
ElementComputeEpilogue,
ElementAccumulator,
cutlass::NumericConverter<ElementOutput, ElementComputeEpilogue> >(problem_size,
tensor_a.host_ref(),
tensor_b.host_ref(),
tensor_c.host_ref(),
tensor_ref_d.host_ref(),
options.alpha,
options.beta);
ElementAccumulator >(problem_size,
tensor_a.host_ref(),
tensor_b.host_ref(),
tensor_c.host_ref(),
tensor_ref_d.host_ref(),
options.alpha,
options.beta);
// Check if output from CUTLASS kernel and reference kernel are equal or not
tensor_d.sync_host();

View File

@ -33,3 +33,7 @@ cutlass_example_add_executable(
ampere_gemm_universal_streamk.cu
)
cutlass_example_add_executable(
47_ampere_gemm_universal_streamk_broadcast
ampere_gemm_universal_streamk_broadcast.cu
)

View File

@ -495,7 +495,7 @@ int main(int argc, const char **argv)
options.tensor_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from CUTLASS kernel
options.tensor_ref_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from reference kernel
// Fill matrix A on host with uniform-random data [4, -4]
// Fill matrix A on host with uniform-random data [-2, 2]
cutlass::reference::host::TensorFillRandomUniform(
options.tensor_a.host_view(),
1,
@ -503,7 +503,7 @@ int main(int argc, const char **argv)
ElementA(-2),
0);
// Fill matrix B on host with uniform-random data [4, -4]
// Fill matrix B on host with uniform-random data [-2, 2]
cutlass::reference::host::TensorFillRandomUniform(
options.tensor_b.host_view(),
1,
@ -511,7 +511,7 @@ int main(int argc, const char **argv)
ElementB(-2),
0);
// Fill matrix C on host with uniform-random data [4, -4]
// Fill matrix C on host with uniform-random data [-2, 2]
cutlass::reference::host::TensorFillRandomUniform(
options.tensor_c.host_view(),
1,

View File

@ -0,0 +1,738 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/***************************************************************************************************
Example contrasting the Stream-K parallel decomposition for GEMM threadblocks versus the
"classic data-parallel" and "Split-K" decompositions + residual add.
For more details regarding the Stream-K method, see "Stream-K: Work-centric Parallel Decomposition
for Dense Matrix-Matrix Multiplication on the GPU" (https://arxiv.org/abs/2301.03598)
Requires NVIDIA Ampere or newer device (SM80+).
- To lock persistence mode, power (400W), clocks (1005MHz) for evaluation (assumes device 0 and A100)
cutlass$ sudo nvidia-smi -pm 1 -i 0
cutlass$ sudo nvidia-smi -i 0 -pl 400
cutlass$ sudo nvidia-smi -i 0 -lgc 1005
- Build and run:
cutlass$ mkdir build
cutlass$ cd build
cutlass/build$ cmake .. -DCUTLASS_NVCC_ARCHS=80
cutlass/build$ make 47_ampere_gemm_universal_streamk_broadcast
cutlass/build$ ./examples/47_ampere_gemm_universal_streamk/47_ampere_gemm_universal_streamk_broadcast
- Reset clocks when done:
cutlass$ sudo nvidia-smi -rgc
**************************************************************************************************/
#include <iostream>
#include <string>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_with_broadcast.h"
#include "cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h"
#include "cutlass/epilogue/thread/linear_combination_residual_block.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/error_metrics.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_foreach.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "helper.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations (cutlass_tensorop_h16816gemm_128x128_32x4_nn_align8)
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::half_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::half_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C1/C2/D matrix configuration
using ElementC = cutlass::half_t; // Element type for C matrix operands
using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrices in units of elements (up to 16 bytes)
// Output matrix configuration
using ElementOutput = cutlass::half_t; // Element type for output matrix operands
using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output matrix operands
// constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value; // Memory access granularity/alignment of output matrices in units of elements (up to 16 bytes)
// Multiply-accumulate blocking/pipelining details
using ElementAccumulator = cutlass::half_t; // Element type for internal accumulation
using ElementCompute = cutlass::half_t; // Element type for compute
using ArchTag = cutlass::arch::Sm80; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock-level tile size (concept: GemmShape)
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp-level tile size (concept: GemmShape)
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // Instruction-level tile size (concept: GemmShape)
constexpr int NumStages = 4; // Number of global->shared pipeline stages used in the GEMM mainloop
constexpr int EVTEpilogueStages = 1;                                     // Number of epilogue stages in the EVT
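// Quick sanity check on the alignment arithmetic above (a sketch one could add; not required by the example):
// for half_t operands, 128 / sizeof_bits<half_t>::value == 128 / 16 == 8 elements, i.e. one 16-byte access.
// static_assert(AlignmentA == 8 && AlignmentB == 8 && AlignmentC == 8, "expected 8-element (16B) alignment for half_t");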
// Residual block configuration
// Epilogue output operator
/// Using LinearCombinationResidualBlock
/// Models a residual block of the form: UnaryOp(BinaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual1), residual2))
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationResidualBlock<
ElementOutput, // Element type for output matrix
ElementAccumulator, // Element type from internal accumulation
ElementCompute, // Element type for compute
ElementC, // Element type for C1/C2/D matrix operands
AlignmentC, // Memory access granularity of C and D matrix in units of elements
cutlass::epilogue::thread::Identity, // Activation
cutlass::plus, // Binary operation 1
cutlass::epilogue::thread::Identity, // Unary operation
cutlass::plus // Binary operation 2
>;
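// With the Identity activations and plus ops chosen above, the documented form
//   UnaryOp(BinaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual1), residual2))
// collapses (for the default alpha = beta = 1 used by this example) to, per element, roughly
//   D[i][j] = acc[i][j] + Vector[j] + C1[i][j] + C2[i][j]
// which is what the host-side reference check later in this file reproduces.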
// Reference device GEMM implementation type
using DeviceGemmReference = cutlass::reference::device::Gemm<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
ElementAccumulator>;
// Classic data-parallel device GEMM implementation type
using DeviceGemmBasic = cutlass::gemm::device::GemmUniversalWithBroadcast<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
NumStages,
AlignmentA,
AlignmentB>;
// StreamK device GEMM implementation type with EVT
using namespace cute;
using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
ThreadblockShape,
WarpShape,
ElementC,
AlignmentC,
EVTEpilogueStages
>;
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, ElementC,
cute::Stride<_0, _1, int32_t> // StrideMNL
>;
using C1 = cutlass::epilogue::threadblock::VisitorAuxLoad<
OutputTileThreadMap, ElementC,
cute::Stride<int64_t, _1, int64_t> // StrideMNL
>;
using C2 = cutlass::epilogue::threadblock::VisitorAuxLoad<
OutputTileThreadMap, ElementC,
cute::Stride<int64_t, _1, int64_t> // StrideMNL
>;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::plus, ElementCompute, ElementCompute,
cutlass::FloatRoundStyle::round_to_nearest
>;
using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<
Compute0,
Accum,
Bias>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::plus, ElementCompute, ElementCompute,
cutlass::FloatRoundStyle::round_to_nearest
>;
using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT<
Compute1,
EVTCompute0,
C1>;
using Compute2 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::plus, ElementOutput, ElementCompute,
cutlass::FloatRoundStyle::round_to_nearest
>;
using EVTCompute2 = cutlass::epilogue::threadblock::Sm80EVT<
Compute2,
EVTCompute1,
C2>;
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, ElementOutput, cutlass::FloatRoundStyle::round_to_nearest,
cute::Stride<int64_t, _1, int64_t> // StrideMNL
>;
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<
D,
EVTCompute2>;
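// Reading the visitor tree bottom-up, the Stream-K kernel's epilogue evaluates
//   EVTCompute0 = plus(Accum, Bias)       // acc + broadcast vector
//   EVTCompute1 = plus(EVTCompute0, C1)   // ... + first residual
//   EVTCompute2 = plus(EVTCompute1, C2)   // ... + second residual
//   EVTD        = store(EVTCompute2)      // write the result to D
// i.e. the same D = acc + Vector + C1 + C2 as the LinearCombinationResidualBlock epilogue above,
// expressed as an explicit compute graph.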
using EVTKernelStreamK =
typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
ElementC, LayoutC, AlignmentC,
ElementAccumulator,
ElementCompute,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
EVTD,
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
NumStages,
cutlass::arch::OpMultiplyAdd,
EVTEpilogueStages
>::GemmKernel;
using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversalAdapter<EVTKernelStreamK>;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(true)
{}
};
/// Command line options parsing
struct Options
{
std::string command_name;
bool help;
cutlass::gemm::GemmCoord problem_size;
float alpha;
float beta;
int split_k_factor;
int avail_sms;
int iterations;
bool real;
cutlass::HostTensor<ElementA, LayoutA> tensor_a;
cutlass::HostTensor<ElementB, LayoutB> tensor_b;
cutlass::HostTensor<ElementC, LayoutC> tensor_c1;
cutlass::HostTensor<ElementC, LayoutC> tensor_c2;
cutlass::HostTensor<ElementC, LayoutC> tensor_d;
cutlass::HostTensor<ElementC, LayoutC> tensor_ref_d;
cutlass::HostTensor<ElementC, LayoutC> tensor_Vector;
// cutlass::HostTensor<ElementC, LayoutC> tensor_Tensor;
Options(std::string command_name) :
command_name(command_name),
help(false),
problem_size({2048, 2048, 2048}),
alpha(1.0f),
beta(1.0f),
split_k_factor(1),
avail_sms(-1), // Number of device SMs to use is unlimited
real(false),
iterations(10000)
{}
bool valid() const
{
return true;
}
void parse(int argc, char const **args)
{
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
cmd.get_cmd_line_argument("m", problem_size.m());
cmd.get_cmd_line_argument("n", problem_size.n());
cmd.get_cmd_line_argument("k", problem_size.k());
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("split", split_k_factor);
cmd.get_cmd_line_argument("iterations", iterations);
real = cmd.check_cmd_line_flag("real");
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const
{
out
<< "Performs a GEMM computation.\n"
<< "\n"
<< "Options:\n"
<< "\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --m=<int> GEMM M dimension\n"
<< " --n=<int> GEMM N dimension\n"
<< " --k=<int> GEMM K dimension\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --split=<int> Split-K factor to emulate\n\n"
<< " --real If specified, initializes with real values instead of whole numbers. Errors are to be expected.\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << command_name << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
return 2.0 * double(problem_size.product()) / double(1.0e9) / runtime_s;
}
};
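// For reference, with the default 2048 x 2048 x 2048 problem, gflops() counts only the 2 * M * N * K
// mainloop multiply-adds: 2 * 2048^3 is roughly 17.18 GFLOP, so a 1 ms average runtime reports about 17,180 GFLOP/s.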
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Populates a DeviceGemmBasic::Arguments structure from the given commandline options
typename DeviceGemmBasic::Arguments args_from_options(
const DeviceGemmBasic &device_gemm,
const Options &options,
cutlass::HostTensor<ElementA, LayoutA> &tensor_a,
cutlass::HostTensor<ElementB, LayoutB> &tensor_b,
cutlass::HostTensor<ElementC, LayoutC> &tensor_c1,
cutlass::HostTensor<ElementC, LayoutC> &tensor_c2,
cutlass::HostTensor<ElementC, LayoutC> &tensor_d,
cutlass::HostTensor<ElementC, LayoutC> &tensor_Vector /*,
cutlass::HostTensor<ElementC, LayoutC> &tensor_Tensor */
)
{
return typename DeviceGemmBasic::Arguments(
cutlass::gemm::GemmUniversalMode::kGemm, // universal mode
options.problem_size, // problem_size
options.split_k_factor, // batch count / splitk slices
{ // epilogue parameters
ElementAccumulator(options.alpha),
ElementAccumulator(options.beta)
},
tensor_a.device_data(), // ptr_A
tensor_b.device_data(), // ptr_B
tensor_c1.device_data(), // ptr_C1
tensor_c2.device_data(), // ptr_C2
tensor_d.device_data(), // ptr_D
tensor_Vector.device_data(), // ptr_Vector
/* tensor_Tensor.device_data(), */ nullptr, // ptr_Tensor
options.problem_size.mk().product(), // batch_stride_A
options.problem_size.nk().product(), // batch_stride_B
options.problem_size.mn().product(), // batch_stride_C1
options.problem_size.mn().product(), // batch_stride_C2
options.problem_size.mn().product(), // batch_stride_D
options.problem_size.mn().product(), // batch_stride_Vector
options.problem_size.mn().product(), // batch_stride_Tensor
tensor_a.layout().stride(0), // stride_a
tensor_b.layout().stride(0), // stride_b
tensor_c1.layout().stride(0), // stride_c1
tensor_c2.layout().stride(0), // stride_c2
tensor_d.layout().stride(0), // stride_d
/*tensor_Vector.layout().stride(0)*/0, // stride_Vector
/*tensor_Tensor.layout().stride(0)*/0); // stride_Tensor
}
/// Populates a DeviceGemmStreamK::Arguments structure from the given commandline options
typename DeviceGemmStreamK::Arguments args_from_options(
const DeviceGemmStreamK &device_gemm,
const Options &options,
cutlass::HostTensor<ElementA, LayoutA> &tensor_a,
cutlass::HostTensor<ElementB, LayoutB> &tensor_b,
cutlass::HostTensor<ElementC, LayoutC> &tensor_c1,
cutlass::HostTensor<ElementC, LayoutC> &tensor_c2,
cutlass::HostTensor<ElementC, LayoutC> &tensor_d,
cutlass::HostTensor<ElementC, LayoutC> &tensor_Vector/*,
cutlass::HostTensor<ElementC, LayoutC> &tensor_Tensor*/
)
{
typename EVTD::Arguments callback_args{
{
{
{
{}, // Accum
{tensor_Vector.device_data(), ElementC(0), {_0{}, _1{}, int32_t(options.problem_size.n())}}, // Bias
{} // Compute0
}, // EVTCompute0
{tensor_c1.device_data(), ElementC(0), {options.problem_size.n(), _1{}, options.problem_size.mn().product()}}, // C1
{} // Compute1
}, // EVTCompute1
{tensor_c2.device_data(), ElementC(0), {options.problem_size.n(), _1{}, options.problem_size.mn().product()}}, // C2
{} // Compute2
}, // EVTCompute2
{tensor_d.device_data(), {options.problem_size.n(), _1{}, options.problem_size.mn().product()}}, // D
}; // EVTD
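// Each visitor node's arguments nest as { first_child_args, ..., last_child_args, node_args },
// mirroring the Sm80EVT composition above: the empty braces are the default-constructed arguments
// of the Accum and Compute nodes, while the pointer/stride tuples parameterize the Bias, C1, C2
// load nodes and the D store node.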
return typename DeviceGemmStreamK::Arguments(
cutlass::gemm::GemmUniversalMode::kGemm, // universal mode
options.problem_size, // problem_size
options.split_k_factor, // batch count / splitk slices
callback_args, // argument of EVT callbacks
tensor_a.device_data(), // ptr_A
tensor_b.device_data(), // ptr_B
nullptr, // ptr_C (unused)
nullptr, // ptr_D (unused)
options.problem_size.mk().product(), // batch_stride_A
options.problem_size.nk().product(), // batch_stride_B
0, // batch_stride_C (unused)
0, // batch_stride_D (unused)
tensor_a.layout().stride(0), // stride_a
tensor_b.layout().stride(0), // stride_b
0, // stride_c (unused)
0, // stride_d (unused)
options.avail_sms); // avail_sms
}
/// Execute a given example GEMM computation
template <typename DeviceGemmT>
Result run(std::string description, Options &options)
{
// Display test description
std::cout << std::endl << description << std::endl;
// Zero-initialize test output matrix D
cutlass::reference::host::TensorFill(options.tensor_d.host_view());
options.tensor_d.sync_device();
// Instantiate CUTLASS kernel depending on templates
DeviceGemmT device_gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of DeviceGemmT
auto arguments = args_from_options(device_gemm, options,
options.tensor_a, options.tensor_b, options.tensor_c1, options.tensor_c2, options.tensor_d,
options.tensor_Vector/*, options.tensor_Tensor*/);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = DeviceGemmT::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check the problem size is supported or not
CUTLASS_CHECK(device_gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(device_gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(device_gemm());
// Copy output data from CUTLASS and reference kernel to host for comparison
options.tensor_d.sync_host();
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = cutlass::reference::host::TensorEquals(
options.tensor_d.host_view(),
options.tensor_ref_d.host_view());
double err = cutlass::reference::host::TensorRelativeErrorMetric(
options.tensor_d.host_view(),
options.tensor_ref_d.host_view());
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << " \t Relative error: " << err << std::endl;
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(device_gemm());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPs: " << result.gflops << std::endl;
}
// TODO: uncomment when results match
//if (!result.passed) {
// exit(-1);
//}
return result;
}
/// Program entrypoint
int main(int argc, const char **argv)
{
// CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples.
if (!(__CUDACC_VER_MAJOR__ >= 11)) {
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
// Current device must have compute capability at least 80
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!((props.major * 10 + props.minor) >= 80))
{
std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
<< std::endl;
// Returning zero so this test passes on older architectures. Its actions are a no-op.
return 0;
}
// Parse commandline options
Options options("ampere_streamk_broadcast_gemm");
options.parse(argc, argv);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
std::cout <<
options.iterations << " timing iterations of " <<
options.problem_size.m() << " x " <<
options.problem_size.n() << " x " <<
options.problem_size.k() << " matrix-matrix multiply" << std::endl;
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
//
// Initialize GEMM datasets
//
// Initialize tensors using CUTLASS helper functions
options.tensor_a.resize(options.problem_size.mk()); // <- Create matrix A with dimensions M x K
options.tensor_b.resize(options.problem_size.kn()); // <- Create matrix B with dimensions K x N
options.tensor_c1.resize(options.problem_size.mn()); // <- Create matrix C1 with dimensions M x N
options.tensor_c2.resize(options.problem_size.mn()); // <- Create matrix C2 with dimensions M x N
options.tensor_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from CUTLASS kernel
options.tensor_ref_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from reference kernel
options.tensor_Vector.resize({1, options.problem_size.n()}); // <- Create broadcast vector with dimensions 1 x N
// options.tensor_Tensor.resize(options.problem_size.mn()); // <- Create T matrix with dimensions M x N
int _init_bits = options.real ? -1 : 0;
// Fill matrix A on host with uniform-random data [-2, 2]
cutlass::reference::host::TensorFillRandomUniform(
options.tensor_a.host_view(),
1,
ElementA(2),
ElementA(-2), _init_bits);
// Fill matrix B on host with uniform-random data [-2, 2]
cutlass::reference::host::TensorFillRandomUniform(
options.tensor_b.host_view(),
1,
ElementB(2),
ElementB(-2), _init_bits);
// Fill matrix C1 on host with uniform-random data [-2, 2]
cutlass::reference::host::TensorFillRandomUniform(
options.tensor_c1.host_view(),
1,
ElementC(2),
ElementC(-2), _init_bits);
// Fill matrix C2 on host with uniform-random data [-2, 2]
cutlass::reference::host::TensorFillRandomUniform(
options.tensor_c2.host_view(),
1,
ElementC(2),
ElementC(-2), _init_bits);
cutlass::reference::host::TensorFillRandomUniform(
options.tensor_Vector.host_view(),
1,
ElementC(2),
ElementC(-2), _init_bits);
//
// Compute reference output
//
// Copy data from host to GPU
options.tensor_a.sync_device();
options.tensor_b.sync_device();
options.tensor_c1.sync_device();
options.tensor_c2.sync_device();
options.tensor_Vector.sync_device();
// options.tensor_Tensor.sync_device();
// Zero-initialize reference output matrix D
cutlass::reference::host::TensorFill(options.tensor_ref_d.host_view());
options.tensor_ref_d.sync_device();
// Create instantiation for device reference gemm kernel
DeviceGemmReference gemm_reference;
// Launch device reference gemm kernel
gemm_reference(
options.problem_size,
ElementAccumulator(options.alpha),
options.tensor_a.device_ref(),
options.tensor_b.device_ref(),
ElementAccumulator(options.beta),
options.tensor_c1.device_ref(),
options.tensor_ref_d.device_ref());
// Wait for kernels to finish
CUDA_CHECK(cudaDeviceSynchronize());
// Copy output data from reference kernel to host for comparison
options.tensor_ref_d.sync_host();
// Add broadcast vector (without multiplier)
// This is only possible because BinaryOp is addition, and UnaryOps are identity.
// This makes the addition of the broadcast vector commutative.
/// identity(plus(identity(alpha * (a * b) + v), beta * c)) ==
/// alpha * a * b + v + beta * c ==
/// (alpha * a * b + beta * c) + v ==
/// GEMM(a, b, c) + v
// Vector broadcast on host
for (int i=0; i < options.problem_size.m(); ++i) {
for (int j=0; j < options.problem_size.n(); ++j) {
options.tensor_ref_d.host_view().ref().at({i, j}) += options.tensor_Vector.host_view().ref().at({0, j});
options.tensor_ref_d.host_view().ref().at({i, j}) += options.tensor_c2.host_view().ref().at({i, j});
}
}
// Sync back with device just in case
options.tensor_ref_d.sync_device();
//
// Evaluate CUTLASS kernels
//
// Test default operation
if (options.split_k_factor == 1)
{
// Compare basic data-parallel version versus StreamK version using default load-balancing heuristics
Result basic_dp = run<DeviceGemmBasic>("Basic data-parallel GEMM", options);
Result streamk_default = run<DeviceGemmStreamK>("StreamK GEMM with default load-balancing", options);
printf(" Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_default.avg_runtime_ms));
// Show that StreamK can emulate basic data-parallel GEMM when the number of SMs to load-balance across is set to 1
options.avail_sms = 1; // Set loadbalancing width to 1 SM (no load balancing)
Result streamk_dp = run<DeviceGemmStreamK>("StreamK emulating basic data-parallel GEMM", options);
options.avail_sms = -1; // Reset loadbalancing width to unspecified SMs (i.e., the number of device SMs)
printf(" Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_dp.avg_runtime_ms));
options.split_k_factor++; // Increment splitting factor for next evaluation
}
// Show that StreamK can emulate "Split-K" with a tile-splitting factor
Result basic_splitk = run<DeviceGemmBasic>(
std::string("Basic split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor),
options);
Result streamk_splitk = run<DeviceGemmStreamK>(
std::string("StreamK emulating Split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor),
options);
printf(" Speedup vs Basic-SplitK: %.3f\n", (basic_splitk.avg_runtime_ms / streamk_splitk.avg_runtime_ms));
return 0;
}

View File

@ -60,6 +60,7 @@
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
@ -95,12 +96,13 @@ constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // M
// C/D matrix configuration
using ElementC = float; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TilesShape = Shape<_128,_128,_32>; // Threadblock-level tile size
using TileShape = Shape<_128,_128,_32>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder
@ -110,15 +112,20 @@ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TilesShape, ClusterShape,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;
using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue<
cutlass::gemm::TagToStrideC_t<LayoutC>,
cutlass::gemm::TagToStrideC_t<LayoutC>,
cutlass::epilogue::thread::LinearCombination<ElementC, 1, ElementAccumulator, ElementAccumulator>>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementC, LayoutC, AlignmentC,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int>, // Indicates ProblemShape
@ -286,10 +293,10 @@ bool initialize_block(
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, Int<1>{}));
stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, Int<1>{}));
stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, Int<1>{}));
stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, Int<1>{}));
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, Int<1>{}));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, Int<1>{}));
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, Int<1>{}));
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, Int<1>{}));
block_A.reset(options.m * options.k);
block_B.reset(options.k * options.n);
@ -308,11 +315,8 @@ typename Gemm::Arguments args_from_options(const Options &options)
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k},
block_A.get(),
stride_A,
block_B.get(),
stride_B,
{block_C.get(), stride_C, block_D.get(), stride_D, {options.alpha, options.beta}}
{block_A.get(), stride_A, block_B.get(), stride_B},
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};
return arguments;
@ -320,7 +324,7 @@ typename Gemm::Arguments args_from_options(const Options &options)
bool verify(const Options &options) {
cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k}));
cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.n, options.k}));
cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.k, options.n}));
cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n}));
cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n}));

View File

@ -77,15 +77,28 @@
will fit in shared memory given the types of operands and the thread block shape, rather than simply using
a single default value.
Note that one does not need to use the CollectiveBuilder to declare CUTLASS 3 kernels; one can still provide
every template parameter to the gemm::collective::CollectiveMma. Specifying every template parameter in this
manner remains the primary API for using CUTLASS 3 kernels. The CollectiveBuilder is simply meant to be
a convenience interface.
CUTLASS 3.x provides builders for both collective mainloops and epilogues. The particular implementation of
the collective is specified via the schedule tags that correspond to the underlying collective's
dispatch policy. `gemm::collective::KernelScheduleAuto` and `epilogue::collective::EpilogueScheduleAuto`
are special cases of these schedules that allow the builder to also decide the dispatch policy for you,
therefore letting the builder pick the collective specialization.
Note also that, while the selections made by CollectiveBuilder attempt to maximize performance, this is not
a guarantee. Furthermore, the behavior of the CollectiveBuilder when `Auto` parameters are provided is subject
to change in future CUTLASS releases -- do not rely on `Auto` if you require a specific scheduling policy and/or
stage count to be used.
CUTLASS builders make an attempt to pick the best schedule when `Auto` is provided such that the
assembled collectives have the best performance, but this is not a guarantee. A user relying on `Auto`
may get a free performance upgrade with newer CUTLASS releases in case we can provide more optimized
implementations that the builder can transparently assemble for `Auto`. But a user should not rely on
`Auto` if they require a specific scheduling policy and/or stage count to be used.
If a user decides to let the builders pick the collective specialization via `Auto` schedules,
they must be used for both mainloop and epilogue alike to ensure compatibility between the
chosen collectives. Additionally, if a user chooses to opt in to a specific schedule, non-`Auto`
schedules must be used for both mainloop and epilogue builder schedules, and these schedules
must be compatible.
One does not need to use the CollectiveBuilder to declare CUTLASS 3 kernels; one can still provide
every template parameter to the `gemm::collective::CollectiveMma`. Specifying every template parameter
in this manner remains the primary API for using CUTLASS 3 kernels. `CollectiveBuilder`s are
simply meant to be a convenience interface.
Details of this example
-----------------------
@ -93,8 +106,15 @@
This example also illustrates how CUTLASS 3 GEMMs targeting Hopper automatically support batched GEMMs by simply
extending the problem size with an additional tensor rank.
CUTLASS 3.2 provides initial support for epilogue visitor trees (EVT) for the TMA warp-specialized collective.
EVTs allow users to define their own customized epilogue fusion patterns without having to write a new
collective epilogue. This is done by representing the fusion as a compute graph, where each node is one of a
fundamental set of load, store, or compute operations. These operations are either elementwise for tensor
inputs/outputs, broadcasts for vector/scalar inputs, or reductions for vector/scalar outputs.
This example shows how users can define their own custom EVT and use it with the CollectiveBuilder.
Example usage:
$ ./examples/49_hopper_gemm_schedules_with_collective_builder/49_hopper_gemm_schedules_with_collective_builder \
$ ./examples/49_hopper_with_collective_builder/49_collective_builder \
--m=2048 --n=2048 --k=2048 --l=2
*/
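As a minimal sketch of the schedule-pairing rule described above (the element types, layouts, tile and
cluster shapes below are illustrative only; the ExampleRunner further down defines the ones this example
actually uses), a mainloop built with `KernelScheduleAuto` is paired with an epilogue built with
`EpilogueScheduleAuto`:
using SketchEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    Shape<_128,_128,_64>, Shape<_1,_1,_1>,
    cutlass::epilogue::collective::EpilogueTileAuto,
    float, float,                                          // accumulator and compute types
    cutlass::half_t, cutlass::layout::ColumnMajor, 8,      // C
    cutlass::half_t, cutlass::layout::ColumnMajor, 8,      // D
    cutlass::epilogue::collective::EpilogueScheduleAuto    // Auto epilogue schedule ...
  >::CollectiveOp;
using SketchMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    cutlass::half_t, cutlass::layout::ColumnMajor, 8,      // A
    cutlass::half_t, cutlass::layout::ColumnMajor, 8,      // B
    float,                                                 // accumulator type
    Shape<_128,_128,_64>, Shape<_1,_1,_1>,
    cutlass::gemm::collective::StageCountAuto,
    cutlass::gemm::collective::KernelScheduleAuto          // ... pairs with the Auto mainloop schedule
  >::CollectiveOp;
Opting in to a specific (non-`Auto`) schedule instead means choosing compatible non-`Auto` schedules on
both builders, as the runners at the end of this example do.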
@ -108,8 +128,10 @@
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
@ -160,7 +182,7 @@ struct Options {
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "49_hopper_gemm_schedules_with_collective_builder\n\n"
out << "49_hopper_with_collective_builder\n\n"
<< " This example showcases the use of CUTLASS's collective operation builders to easily construct\n"
<< " performant kernels targeting NVIDIA's Hopper architecture.\n\n"
<< "Options:\n\n"
@ -212,16 +234,30 @@ bool initialize_block(
// operation builders by specializing the GEMM only on the kernel schedule it will use and the
// number of pipeline stages.
//
// For either option, one can use a special `Auto` type that tells the CollectiveBuilder
// One can use a special `Auto` type that tells the CollectiveBuilder
// to select an appropriate value on its own. The CollectiveBuilder will attempt to select
// values that will result in the most-performant kernel, but this is not a guarantee. Furthermore,
// the behavior of the CollectiveBuilder with `Auto` types is subject to change in future releases
// configurations that will result in the most-performant kernel, but this is not a guarantee.
//
// If relying on 'Auto' schedules, all builders must use the 'Auto' schedule to ensure compatibility.
// For example, if `KernelScheduleAuto` is used for the mainloop builder, `EpilogueScheduleAuto` must
// be used for the epilogue builder.
//
// Furthermore, if an override schedule is selected, both epilogue and mainloop schedules must
// be specifically opted into a compatible selection.
//
// Behavior of the CollectiveBuilder with `Auto` types is subject to change in future releases
// -- do not rely on `Auto` if you require a specific scheduling policy.
template <
// Type of kernel schedule to generate
class KernelScheduleType = cutlass::gemm::collective::KernelScheduleAuto,
class MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto,
// Type of epilogue schedule to generate
class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto,
// Number of pipeline stages to use
class StageCountType = cutlass::gemm::collective::StageCountAuto
class StageCountType = cutlass::gemm::collective::StageCountAuto,
// Type of tile scheduler to use
class TileSchedulerType = cutlass::gemm::PersistentScheduler,
// Do we use custom epilogue visitor tree (EVT) fusion
bool UseCustomEVT = false
>
struct ExampleRunner {
@ -230,27 +266,72 @@ struct ExampleRunner {
using LayoutC = cutlass::layout::ColumnMajor;
using LayoutD = cutlass::layout::ColumnMajor;
static constexpr int kAlignmentA = 8;
static constexpr int kAlignmentB = 8;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
using ElementA = cutlass::half_t;
using ElementB = cutlass::half_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
using ElementAccumulator = float;
using ElementCompute = float;
using ElementScalar = float;
// 16B alignment lets us use TMA
static constexpr int AlignmentA = 16 / sizeof(ElementA);
static constexpr int AlignmentB = 16 / sizeof(ElementB);
static constexpr int AlignmentC = 16 / sizeof(ElementC);
static constexpr int AlignmentD = 16 / sizeof(ElementD);
static_assert(not UseCustomEVT ||
(cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecialized> ||
cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecializedCooperative>),
"Epilogue visitor trees are currently only supported by the TMA warp-specialized epilogue");
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
// EVTs can be constructed by composing the fundamental load/store/compute visitor operations defined in include/cutlass/epilogue/fusion
// For more complex examples of EVT construction please refer to include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp
using CustomEVT = // alpha * acc + beta * C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::multiply_add, ElementD, ElementCompute, RoundStyle>, // beta * C + (alpha * acc)
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementScalar>, // beta
cutlass::epilogue::fusion::Sm90SrcFetch, // C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementScalar>, // alpha
cutlass::epilogue::fusion::Sm90AccFetch // acc
>
>;
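// Reading the tree above: each Sm90EVT<Op, Child0, ..., ChildN> node evaluates its children first and
// then applies Op to their results, so CustomEVT computes
//   multiply_add(beta, C, multiplies(alpha, acc)) == beta * C + alpha * acc
// per element -- the same math as the default linear combination, just spelled out as a visitor tree.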
// A predefined set of fusion operations (implemented with EVT) are supported by the TMA warp-specialized epilogue.
// Users can select one of these operations by passing one of the tags defined in include/cutlass/epilogue/fusion/operations.hpp
// to the CollectiveBuilder. This frees the user from having to compute additional parameters such as stage counts and copy atoms/layouts.
// These tags also provide additional metadata that can be queried at compile time.
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementScalar, RoundStyle>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, LayoutA, kAlignmentA,
cutlass::half_t, LayoutB, kAlignmentB,
float,
Shape<_128,_128,_64>, Shape<_2,_1,_1>,
StageCountType,
KernelScheduleType
Shape<_128,_128,_64>, Shape<_1,_1,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementCompute,
ElementC, LayoutC, AlignmentC,
ElementD, LayoutD, AlignmentD,
EpilogueScheduleType,
cute::conditional_t<UseCustomEVT, CustomEVT, DefaultOperation>
>::CollectiveOp;
using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue<
cutlass::gemm::TagToStrideC_t<LayoutC>,
cutlass::gemm::TagToStrideC_t<LayoutD>,
cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 1, float, float>>;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
Shape<_128,_128,_64>, Shape<_2,_1,_1>,
cute::conditional_t<cute::is_same_v<StageCountType, cutlass::gemm::collective::StageCountAuto>,
cutlass::gemm::collective::StageCountAutoCarveout<(int)sizeof(typename CollectiveEpilogue::SharedStorage)>,
StageCountType>,
MainloopScheduleType
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue
CollectiveEpilogue,
TileSchedulerType
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
@ -262,10 +343,10 @@ struct ExampleRunner {
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutTagA = decltype(cutlass::gemm::detail::stride_to_layout_tag_A<StrideA>());
using LayoutTagB = decltype(cutlass::gemm::detail::stride_to_layout_tag_B<StrideB>());
using LayoutTagC = decltype(cutlass::gemm::detail::stride_to_layout_tag_A<StrideC>());
using LayoutTagD = decltype(cutlass::gemm::detail::stride_to_layout_tag_A<StrideD>());
using LayoutTagA = cutlass::gemm::detail::StrideToLayoutTagA_t<StrideA>;
using LayoutTagB = cutlass::gemm::detail::StrideToLayoutTagB_t<StrideB>;
using LayoutTagC = cutlass::gemm::detail::StrideToLayoutTagC_t<StrideC>;
using LayoutTagD = cutlass::gemm::detail::StrideToLayoutTagC_t<StrideD>;
//
// Data members
@ -281,8 +362,8 @@ struct ExampleRunner {
cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
cutlass::DeviceAllocation<typename Gemm::ElementD> block_D;
cutlass::DeviceAllocation<typename Gemm::ElementD> block_ref_D;
//
// Methods
@ -298,15 +379,15 @@ struct ExampleRunner {
cutlass::reference::device::GemmComplex(
{M, N, K},
typename Gemm::EpilogueOutputOp::ElementCompute(alpha),
ElementScalar(alpha),
ref_A,
cutlass::ComplexTransform::kNone,
ref_B,
cutlass::ComplexTransform::kNone,
typename Gemm::EpilogueOutputOp::ElementCompute(beta),
ElementScalar(beta),
ref_C,
ref_D,
typename Gemm::EpilogueOutputOp::ElementAccumulator(0.f),
ElementAccumulator(0),
L, // batch_count
M * K, // batch_stride_A
K * N, // batch_stride_B
@ -332,10 +413,10 @@ struct ExampleRunner {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
auto [M, N, K, L] = problem_shape_MNKL;
stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
block_A.reset(M * K * L);
block_B.reset(K * N * L);
@ -356,14 +437,37 @@ struct ExampleRunner {
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
block_A.get(),
stride_A,
block_B.get(),
stride_B,
{block_C.get(), stride_C, block_D.get(), stride_D, {options.alpha, options.beta}},
{block_A.get(), stride_A, block_B.get(), stride_B},
{{}, // epilogue.thread
block_C.get(), stride_C, block_D.get(), stride_D},
hw_info
};
// Custom EVT fusions will have nested unnamed args, the structure of which
// can be deduced from the type definition of the EVT.
// Each node's arguments have the recursive structure of
// {first_child_args, ..., last_child_args, op_args},
// For more complex examples of EVT initialization please refer to
// include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp
if constexpr (UseCustomEVT) {
arguments.epilogue.thread =
{ // ternary op : beta * C + (alpha * acc)
{{options.beta}}, // leaf op+args : beta
{}, // leaf op+args : C
{ // binary op : alpha * acc
{{options.alpha}}, // leaf op+args : alpha
{}, // leaf op+args : acc
{} // binary args : multiplies
}, // end binary op
{} // ternary args : multiply_add
}; // end ternary op
}
// Pre-defined fusions will have flat, named args for user-friendliness
else {
arguments.epilogue.thread.alpha = options.alpha;
arguments.epilogue.thread.beta = options.beta;
}
Gemm gemm_op;
size_t workspace_size = Gemm::get_workspace_size(arguments);
@ -477,42 +581,69 @@ int main(int argc, char const **args) {
// selected and the maximum number of stages that can fit in shared memory will be selected.
//
// This example is equivalent to declaring
// ExampleRunner<cutlass::gemm::collective::KernelScheduleAuto, cutlass::gemm::collective::StageCountAuto>
// ExampleRunner<
// cutlass::gemm::collective::KernelScheduleAuto,
// cutlass::epilogue::collective::EpilogueScheduleAuto,
// cutlass::gemm::collective::StageCountAuto>
// Each of the `Auto` types indicate that the CollectiveBuilder should determine the scheduling policy and
// stage count. Note that the behavior of the CollectiveBuilder with `Auto` parameters is subject to change
// -- do not rely on `Auto` if you require a specific scheduling policy.
// If you opt in to a non-'Auto' schedule, make sure all collectives are built using specific, compatible schedules.
ExampleRunner<> auto_schedule_auto_stage_runner;
passed = auto_schedule_auto_stage_runner.run(options, hw_info);
print_result("Automatically-selected schedule and stage count", passed);
// One can override the stage count used in the GEMM by replacing cutlass::gemm::collective::StageCountAuto
// with the number of stages to use (5 in this case).
ExampleRunner<cutlass::gemm::collective::KernelScheduleAuto, _5> auto_schedule_5_stage_runner;
ExampleRunner<
cutlass::gemm::collective::KernelScheduleAuto,
cutlass::epilogue::collective::EpilogueScheduleAuto,
_5> auto_schedule_5_stage_runner;
passed = auto_schedule_5_stage_runner.run(options, hw_info);
print_result("Automatically-selected schedule with 5 stages", passed);
// One can also override the scheduling policy to use. In this case, use the KernelTma scheduling
// policy, which specifies that the Hopper TMA feature should be used.
ExampleRunner<cutlass::gemm::KernelTma> tma_schedule_auto_stage_runner;
// policy, which specifies that the Hopper TMA feature should be used, and we also use an epilogue
// that does not use any shared memory.
ExampleRunner<cutlass::gemm::KernelTma, cutlass::epilogue::NoSmemWarpSpecialized> tma_schedule_auto_stage_runner;
passed = tma_schedule_auto_stage_runner.run(options, hw_info);
print_result("TMA schedule with automatically-selected stage count", passed);
// Here, we override the scheduling policy to use Hopper's TMA feature alongside the warp-specialized
// scheduling policy.
//
// Note that, as of the CUTLASS 3.0 release, this is the default scheduling policy
// used by the CollectiveBuilder, so this declaration is equivalent to ExampleRunner<> and
// ExampleRunner<cutlass::gemm::collective::KernelScheduleAuto>. However, this default is subject to
// change in future releases -- do not rely on `Auto` if you require a specific scheduling policy.
ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized> ws_schedule_auto_stage_runner;
// scheduling policy, and an epilogue that does not use any shared memory.
ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized, cutlass::epilogue::NoSmemWarpSpecialized> ws_schedule_auto_stage_runner;
passed = ws_schedule_auto_stage_runner.run(options, hw_info);
print_result("Warp-specialized TMA schedule with automatically-selected stage count", passed);
// Finally, we override the scheduling policy to use Hopper's TMA feature, alongside the warp-specialized
// scheduling policy, leveraging persistent thread blocks.
ExampleRunner<cutlass::gemm::KernelTmaWarpSpecializedPersistent> ws_persistent_schedule_auto_stage_runner;
passed = ws_persistent_schedule_auto_stage_runner.run(options, hw_info);
print_result("Persistent warp-specialized TMA schedule with automatically-selected stage count", passed);
// Here, we override the scheduling policy to use Hopper's TMA feature alongside the warp-specialized
// ping-pong scheduling policy and a TMA-based epilogue, leveraging persistent thread blocks.
ExampleRunner<
cutlass::gemm::KernelTmaWarpSpecializedPingpong,
cutlass::epilogue::TmaWarpSpecialized> ws_pingpong_schedule_auto_stage_runner;
passed = ws_pingpong_schedule_auto_stage_runner.run(options, hw_info);
print_result("Ping-pong warp-specialized TMA schedule with automatically-selected stage count", passed);
// Here, we override the scheduling policy to use stream-K problem decomposition atop the cooperative
// warp-specialized scheduling policy. This kernel continues to leverage persistent thread blocks
// as well as TMA in both the mainloop and epilogue.
ExampleRunner<
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
cutlass::epilogue::TmaWarpSpecializedCooperative,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::StreamKScheduler> ws_cooperative_stream_k_schedule_auto_stage_runner;
passed = ws_cooperative_stream_k_schedule_auto_stage_runner.run(options, hw_info);
print_result("Cooperative warp-specialized TMA schedule using stream-K with automatically-selected stage count", passed);
// Here, we override the fusion operation to use a customized EVT fusion, in addition to the previous schedule overrides
ExampleRunner<
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
cutlass::epilogue::TmaWarpSpecializedCooperative,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::PersistentScheduler,
true> ws_cooperative_schedule_auto_stage_custom_evt_runner;
passed = ws_cooperative_schedule_auto_stage_custom_evt_runner.run(options, hw_info);
print_result("Cooperative warp-specialized TMA schedule using custom epilogue visitor tree with automatically-selected stage count", passed);
#endif

View File

@ -0,0 +1,34 @@
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Both filenames are shorter to avoid MAX_PATH issues on Windows.
cutlass_example_add_executable(
49_collective_builder
49_collective_builder.cu
)

View File

@ -34,7 +34,7 @@
The following example shows how to assemble a custom GEMM kernel that spells out the Collectives
directly instead of using a builder and, in the process, instance a more efficient Epilogue
(from `cutlass/epilogue/collective/epilogue.hpp`) instead of using the default epilogue.
(from `cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp`) instead of using the default epilogue.
The GemmUniversal API takes 3 main template arguments:
(1) the problem shape / extents
@ -65,7 +65,7 @@
#include "cute/tensor.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/epilogue.hpp"
#include "cutlass/epilogue/collective/collective_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
@ -122,7 +122,7 @@ struct Options {
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "50_hopper_gemm_with_vectorized_epilogue\n\n"
out << "50_hopper_gemm_with_epilogue_swizzle\n\n"
<< "Hopper GEMM Example with Epilogue Swizzle.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
@ -262,10 +262,10 @@ struct ExampleRunner {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
auto [M, N, K, L] = problem_shape_MNKL;
stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
block_A.reset(M * K * L);
block_B.reset(K * N * L);
@ -286,11 +286,8 @@ struct ExampleRunner {
typename Gemm::GemmKernel::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
block_A.get(),
stride_A,
block_B.get(),
stride_B,
{block_C.get(), stride_C, block_D.get(), stride_D, {options.alpha, options.beta}},
{block_A.get(), stride_A, block_B.get(), stride_B},
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D},
hw_info
};
@ -443,11 +440,11 @@ int main(int argc, char const **args) {
cute::SM90_TMA_LOAD,
cute::SM90_TMA_LOAD_MULTICAST>::type;
using SmemLayoutAtomA = decltype(cute::GMMA::smem_selector<
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape{})), decltype(cute::get<2>(TileShape{}))
>());
using SmemLayoutAtomB = decltype(cute::GMMA::smem_selector<
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape{})), decltype(cute::get<2>(TileShape{}))
>());
@ -494,14 +491,15 @@ int main(int argc, char const **args) {
Stride<_16,_1>>,
TileShapeS2R>;
using Epilogue = cutlass::epilogue::collective::Epilogue<
using Epilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
cutlass::epilogue::collective::Epilogue<
cutlass::gemm::TagToStrideC_t<LayoutC>,
cutlass::gemm::TagToStrideC_t<LayoutD>,
cutlass::epilogue::thread::LinearCombination<int32_t, 1, int32_t, int32_t>,
SmemLayout,
Copy_Atom<DefaultCopy, ElementAcc>,
TiledCopyS2R,
Copy_Atom<DefaultCopy, ElementOutput>>;
Copy_Atom<DefaultCopy, ElementOutput>>>;
//
// Assembling the GemmKernel

View File

@ -0,0 +1,371 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Example of a GETT targeting Hopper tensor cores using the CUTLASS 3.x API.
CUTLASS has long provided implementations of Generalized Matrix times Matrix (GEMM) kernels.
However, a plethora of workloads compute on higher ranked tensors. Products of such tensors,
called tensor contractions, can be executed as multiple batched GEMMs; however, they can be
further accelerated with kernels that natively operate on these higher ranked tensors to
perform Generalized Tensor times Tensor contractions (GETT). CuTe's hierarchical layouts
and CUTLASS 3.0's unified micro-kernels make implementation of GETTs trivial. In this example,
we show how CUTLASS 3.0, CuTe, and Hopper's TMA feature together can accelerate GETTs while
making the process of authoring custom GETT kernels easier than ever before.
The modes of a tensor that participate in a GETT can be fundamentally grouped into four
semantic categories. The contraction modes (or K-modes) only appear in the A and B (left and right)
inputs but not in the C output tensor. Row modes (or M-modes) only appear in the left
input tensor (A) and the output tensor (C). Column modes (or N-modes) only appear in the
right (B) input tensor and the output tensor (C). Batch modes (or L-modes) appear in all
input and output tensors. If we fold the many modes of a tensor contraction into these four
categories, we can represent the input and output tensors as rank-3 "matrices"
that can be computed upon as if we were computing a batched GEMM!
This is exactly what CuTe's hierarchical layout representation allows us to do! Instead of having
simple integers as strides for these four modes, we can have nested strides for each of these
semantic categories that themselves have multiple modes within them -- multi-mode strides!
In CUTLASS 3.0, all one has to do to take advantage of this capability is to substitute the
required multi-mode strides instead of the default ones provided by gemm::detail::TagToStrideX.
In the following example, we illustrate how every Hopper GEMM in CUTLASS 3.0 is a GETT in disguise.
We begin by defining the four modes detailed above as Row, Col (column), Red (reduction), and
Bat (batch) strides, which we then nest for each of the in/out tensors to create our rank-3 stride
tuples. Note that although we do not define the problem shape type explicitly, it too remains a
rank-4 shape tuple just like any other batched GEMM, but instead with multi-mode shapes for each
of the four corresponding multi-modes within it. After this, the same CollectiveMma and
CollectiveBuilder we describe in examples 50 and 49 are used to create our kernel type. Nothing
else changes from a user's point of view. Note that multi-mode strides do not affect our
specializations in any way -- the lexical spelling of our kernels remains the same. The
only difference between a CUTLASS 3 batched GEMM and a GETT is the instantiated CuTe Layouts.
CollectiveBuilders rely on detecting the static-1 in the stride tuples to determine the major mode,
which is what the example demonstrates. However, it is possible to have all modes be dynamic as well
if the user assembles a CollectiveMma manually and ensures that the runtime strides are compatible
with the static micro-kernel of the collective (TiledMma, TiledCopy, and smem layouts). On the other
hand, a user can have more than one static stride too (which need not correspond to the major mode).
In particular, this example demonstrates a GETT where the 0th M-mode (M0) in A and the 0th K-mode (K0)
in B are major. All other combinations of major modes are supported, with the exception of mixed
K-major scenarios where both A and B are K-major (e.g. K0 is major in A but K1 is major in B).
NVIDIA Hopper architecture's TMA feature makes the predication required to implement these complicated
kernels trivial, as it is all handled by TMA itself without requiring any programmer effort.
Example executions, where the stride order defines the major-order (major on the left):
51_hopper_gett --modeC=m,n,l --modeA=m,k,l --modeB=k,n,l --extents=m:4096,n:4096,k:4096
51_hopper_gett --modeC=l,m,n --modeA=m,l,k --modeB=k,n,l --extents=m:128,n:128,k:128,l:64
51_hopper_gett --modeC=m,a,b,p,q,n,l --modeA=m,l,b,k,a --modeB=k,n,p,q,l --extents=m:32,a:32,b:3,n:128,k:128,l:4,p:3,q:3
*/
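
Before the includes, a minimal sketch of the point above about substituting multi-mode strides for the defaults produced by gemm::detail::TagToStrideX; the mode counts and type names below are illustrative assumptions only, and the actual stride types used by this example are defined further down in main().

#include "cute/tensor.hpp"

// Sketch only. A conventional batched GEMM describes A with a flat rank-3 stride:
using FlatStrideA = cute::Stride<cute::Int<1>, int64_t, int64_t>;     // (M, K, L), M-major

// A GETT keeps the rank-3 top-level structure but nests one stride tuple per semantic
// mode group, so M, K, and L can each contain several modes of their own:
using SketchRowModes = cute::Stride<cute::Int<1>, int64_t>;           // (M0, M1), M0 major
using SketchRedModes = cute::Stride<int64_t, int64_t>;                // (K0, K1)
using SketchBatModes = cute::Stride<int64_t>;                         // (L0)
using NestedStrideA  = cute::Stride<SketchRowModes, SketchRedModes, SketchBatModes>;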
#include "gett_kernel.cuh"
#include "thrust/host_vector.h"
#include "thrust/device_vector.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/util/gett_commandline.hpp"
#include "cutlass/util/reference/device/gett.hpp"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/print_error.hpp"
namespace example {
// Returns true if the left-most value in the tuple is statically known to be 1
template<class Stride>
constexpr bool
is_left_major() {
// Account for stride types with and without batch mode and batch modes with static zero stride
return cute::is_constant<1, decltype(cute::size<0,0>(Stride{}))>::value;
}
// Same as cute::make_int_tuple but inserts a major stride (Int<1>) for the leftmost mode if required
template <int Rank, bool IsMajor, class Indexable>
static constexpr
auto
make_stride_tuple(Indexable const& t, int n, int64_t init_default = 0) {
static_assert(Rank > 1);
if constexpr (IsMajor) {
return cute::transform(cute::make_seq<Rank>{}, [&](auto i) {
if constexpr (i == 0) {
return cute::Int<1>{};
}
else {
return i < n ? t[i] : init_default;
}
});
}
else {
return cute::make_int_tuple<Rank>(t, n, init_default);
}
}
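A quick usage note for the helper above; the leading dimensions are hypothetical values chosen only to show the two branches.
// Hypothetical usage sketch (not part of the example):
//   int64_t ld[] = {1, 32, 1024};
//   auto d_major    = example::make_stride_tuple<4, true >(ld, 3);  // yields (_1, 32, 1024, 0)
//   auto d_nonmajor = example::make_stride_tuple<4, false>(ld, 3);  // forwards to cute::make_int_tuple<4>(ld, 3, 0)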
} // namespace example
//////////////////////////////////////////////////////////////////////////////
int
main(int argc, char const* argv[]) {
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
using namespace cute;
if (argc != 5) {
std::cout << "Number of command line args must be 4.\n";
cutlass::GettCommandLine::print_usage();
return 0;
}
//
// Define the stride types for A, B, C, and D
//
// Stride for A (left input). If reduction mode is major, same must be major in B
// For this example, M0 is major in A.
using RowModeStridesA = cute::Stride<cute::Int<1>, int64_t, int64_t, int64_t>;
using RedModeStridesA = cute::Stride<int64_t, int64_t, int64_t>;
using BatModeStridesA = cute::Stride<int64_t, int64_t, int64_t, int64_t>;
// Stride for B (right input). If reduction mode is major, same must be major in A
// For this example, K0 is major in B.
using ColModeStridesB = cute::Stride<int64_t, int64_t, int64_t, int64_t>;
using RedModeStridesB = cute::Stride<cute::Int<1>, int64_t, int64_t>;
using BatModeStridesB = cute::Stride<int64_t, int64_t, int64_t, int64_t>;
// Strides for output, which can all be dynamic.
using RowModeStridesC = cute::Stride<int64_t, int64_t, int64_t, int64_t>;
using ColModeStridesC = cute::Stride<int64_t, int64_t, int64_t, int64_t>;
using BatModeStridesC = cute::Stride<int64_t, int64_t, int64_t, int64_t>;
// Assemble our rank-3 multi-mode strides for the in/out tensors
using StrideA = cute::Stride<RowModeStridesA, RedModeStridesA, BatModeStridesA>;
using StrideB = cute::Stride<ColModeStridesB, RedModeStridesB, BatModeStridesB>;
using StrideC = cute::Stride<RowModeStridesC, ColModeStridesC, BatModeStridesC>;
// Note: C and D share strides here for simplicity.
// In general, they need not have the same layout.
using StrideD = StrideC;
//
// Define element types for tensors and intermediate values
//
using ElementA = cutlass::half_t;
using ElementB = cutlass::half_t;
using ElementC = cutlass::half_t;
using ElementD = float;
using ElementAccumulator = float;
using ElementEpilogue = float;
// The following constexpr values set the max number of modes in each of the M, N, K, and L mode groups
constexpr int MaxRank_M = rank(RowModeStridesA{}); // Max row modes
constexpr int MaxRank_N = rank(ColModeStridesB{}); // Max column modes
constexpr int MaxRank_K = rank(RedModeStridesA{}); // Max contraction modes
constexpr int MaxRank_L = rank(BatModeStridesA{}); // Max batch modes
static_assert(rank(RowModeStridesA{}) == rank(RowModeStridesC{}));
static_assert(rank(ColModeStridesB{}) == rank(ColModeStridesC{}));
static_assert(rank(RedModeStridesA{}) == rank(RedModeStridesB{}));
static_assert(rank(BatModeStridesA{}) == rank(BatModeStridesC{}));
static_assert(rank(BatModeStridesB{}) == rank(BatModeStridesC{}));
// Parse command line to get modes, extents, and strides
cutlass::GettCommandLine cmd;
auto parsed_args = cmd.parse(argc, argv, true);
auto& m = parsed_args.M;
auto& ldAm = parsed_args.ldAm;
auto& ldCm = parsed_args.ldCm;
int rank_m = int(m.size());
auto& n = parsed_args.N;
auto& ldBn = parsed_args.ldBn;
auto& ldCn = parsed_args.ldCn;
int rank_n = int(n.size());
auto& k = parsed_args.K;
auto& ldAk = parsed_args.ldAk;
auto& ldBk = parsed_args.ldBk;
int rank_k = int(k.size());
auto& l = parsed_args.L;
auto& ldAl = parsed_args.ldAl;
auto& ldBl = parsed_args.ldBl;
auto& ldCl = parsed_args.ldCl;
int rank_l = int(l.size());
if ((rank_m > MaxRank_M) || (rank_n > MaxRank_N) || (rank_k > MaxRank_K) || (rank_l > MaxRank_L)) {
std::cerr << "ERROR: Input has more modes than statically configured.";
return 1;
}
// Check that the user-input major strides match the static major strides.
if (example::is_left_major<RowModeStridesA>() && (ldAm[0] != 1)) {
std::cerr << "ERROR: A_M0 is expected to be major, but was not in the provided input!\n";
return 1;
}
if (example::is_left_major<RedModeStridesA>() && (ldAk[0] != 1)) {
std::cerr << "ERROR: A_K0 is expected to be major, but was not in the provided input!\n";
return 1;
}
if (example::is_left_major<ColModeStridesB>() && (ldBn[0] != 1)) {
std::cerr << "ERROR: B_N0 is expected to be major, but was not in the provided input!\n";
return 1;
}
if (example::is_left_major<RedModeStridesB>() && (ldBk[0] != 1)) {
std::cerr << "ERROR: B_K0 is expected to be major, but was not in the provided input!\n";
return 1;
}
// Convert to `cute::Tuple`s and set up arguments
auto M = make_int_tuple<MaxRank_M>(m.data(), rank_m, 1);
auto dAm = example::make_stride_tuple<MaxRank_M, example::is_left_major<RowModeStridesA>()>(ldAm.data(), rank_m);
auto dCm = example::make_stride_tuple<MaxRank_M, example::is_left_major<RowModeStridesC>()>(ldCm.data(), rank_m);
auto N = make_int_tuple<MaxRank_N>(n.data(), rank_n, 1);
auto dBn = example::make_stride_tuple<MaxRank_N, example::is_left_major<ColModeStridesB>()>(ldBn.data(), rank_n);
auto dCn = example::make_stride_tuple<MaxRank_N, example::is_left_major<ColModeStridesC>()>(ldCn.data(), rank_n);
auto K = make_int_tuple<MaxRank_K>(k.data(), rank_k, 1);
auto dAk = example::make_stride_tuple<MaxRank_K, example::is_left_major<RedModeStridesA>()>(ldAk.data(), rank_k);
auto dBk = example::make_stride_tuple<MaxRank_K, example::is_left_major<RedModeStridesB>()>(ldBk.data(), rank_k);
auto L = make_int_tuple<MaxRank_L>(l.data(), rank_l, 1);
auto dAl = make_int_tuple<MaxRank_L>(ldAl.data(), rank_l, 0);
auto dBl = make_int_tuple<MaxRank_L>(ldBl.data(), rank_l, 0);
auto dCl = make_int_tuple<MaxRank_L>(ldCl.data(), rank_l, 0);
// Concat tuples to turn it into rank-4 problem shape and rank-3 strides, just like GEMM
auto problem_shape = make_shape(M, N, K, L);
StrideA stride_A = make_stride(dAm, dAk, dAl);
StrideB stride_B = make_stride(dBn, dBk, dBl);
StrideC stride_C = make_stride(dCm, dCn, dCl);
StrideD stride_D = stride_C;
auto alpha = ElementEpilogue(1.0f);
auto beta = ElementEpilogue(1.0f);
//
// Allocate and init tensors
//
auto M_size = std::accumulate(std::begin(m), std::end(m), 1, std::multiplies<>{});
auto N_size = std::accumulate(std::begin(n), std::end(n), 1, std::multiplies<>{});
auto K_size = std::accumulate(std::begin(k), std::end(k), 1, std::multiplies<>{});
auto L_size = std::accumulate(std::begin(l), std::end(l), 1, std::multiplies<>{});
thrust::host_vector<ElementA> h_A(M_size * K_size * L_size);
thrust::host_vector<ElementB> h_B(N_size * K_size * L_size);
thrust::host_vector<ElementC> h_C(M_size * N_size * L_size);
thrust::host_vector<ElementD> h_D(M_size * N_size * L_size);
// Note: the cast to int here is to avoid false-negative ref-checks which can
// occur due to floating point arithmetic not being purely associative.
for (auto& a : h_A) a = ElementA(int(4*(rand() / double(RAND_MAX)) - 1));
for (auto& b : h_B) b = ElementB(int(4*(rand() / double(RAND_MAX)) - 1));
for (auto& c : h_C) c = ElementC(int(4*(rand() / double(RAND_MAX)) - 1));
for (auto& d : h_D) d = ElementD(-1);
thrust::device_vector<ElementA> d_A = h_A;
thrust::device_vector<ElementB> d_B = h_B;
thrust::device_vector<ElementC> d_C = h_C;
thrust::device_vector<ElementD> cutlass_result = h_D;
thrust::device_vector<ElementD> reference_result = h_D;
//
// Compute GETT
//
auto status = example::gett_kernel(
problem_shape,
d_A.data().get(), stride_A,
d_B.data().get(), stride_B,
ElementAccumulator{},
d_C.data().get(), stride_C,
cutlass_result.data().get(), stride_D,
alpha, beta);
if (cutlass::Status::kSuccess != status) {
std::cerr << "ERROR: GETT operator launch failed.\n";
return 1;
}
auto cuda_err = cudaDeviceSynchronize();
if (cudaSuccess != cuda_err) {
std::cerr << "ERROR: GETT operator execution failed. with error :";
std::cerr << cudaGetErrorString(cuda_err) << "\n";
return 1;
}
//
// Verify
//
cutlass::reference::device::gett(
problem_shape,
d_A.data().get(), stride_A,
d_B.data().get(), stride_B,
ElementAccumulator{},
d_C.data().get(), stride_C,
reference_result.data().get(), stride_D,
alpha, beta);
cuda_err = cudaDeviceSynchronize();
if (cudaSuccess != cuda_err) {
std::cerr << "ERROR: GETT reference execution failed. with error :";
std::cerr << cudaGetErrorString(cuda_err) << "\n";
return 1;
}
// Check if output from CUTLASS kernel and reference kernel are equal or not
bool passed = cutlass::reference::device::BlockCompareEqual(
reference_result.data().get(), cutlass_result.data().get(), cutlass_result.size());
if (passed) {
std::cout << "GETT verification passed.\n";
return 0;
}
else {
std::cerr << "ERROR: GETT verification failed! Printing detailed stats.\n";
h_D = reference_result;
thrust::host_vector<ElementD> h_cutlass_result = cutlass_result;
print_relative_error(h_cutlass_result.size(), h_cutlass_result.data(), h_D.data());
std::cout << "StrideA: "; print(stride_A); std::cout << '\n';
std::cout << "StrideB: "; print(stride_B); std::cout << '\n';
std::cout << "StrideC: "; print(stride_C); std::cout << '\n';
std::cout << "StrideD: "; print(stride_D); std::cout << '\n';
return 1;
}
#else
std::cerr << "Unsupported example. Please ensure CUTLASS_ARCH_MMA_SM90_SUPPORTED is defined.\n";
return 0;
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
}

View File

@ -0,0 +1,32 @@
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
51_hopper_gett
51_hopper_gett.cu
)

View File

@ -0,0 +1,137 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/tensor.hpp"
#include "cutlass/arch/arch.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
namespace example {
//
// GETT entry point
//
template <
class ProblemShapeMNKL,
class ElementA,
class StrideA,
class ElementB,
class StrideB,
class ElementAccumulator,
class ElementC,
class StrideC,
class ElementD,
class StrideD,
class ElementEpilogue>
cutlass::Status
gett_kernel(
ProblemShapeMNKL problem_shape_mnkl,
ElementA const* ptr_A, StrideA stride_a_mkl,
ElementB const* ptr_B, StrideB stride_b_nkl,
ElementAccumulator _,
ElementC const* ptr_C, StrideC stride_c_mnl,
ElementD * ptr_D, StrideD stride_d_mnl,
ElementEpilogue alpha, ElementEpilogue beta,
cudaStream_t stream = 0) {
using namespace cute;
// TileShape -- GETT configuration
// Specify the number of elements to take from each mode
// BLK_M = (M0,M1,...) BLK_N = (N0,N1,...) BLK_K = (K0,K1,...)
// Take 128 from m0, 128 from n0, 64 from k0
using TileShape = Shape<Shape<_128>, Shape<_128>, Shape<_64>>;
/* Other examples:
* Take 32 elements from m0 and 4 elements from m1
* Take 64 elements from n0 and 2 elements from n1
* Take 8 elements from k0 and 8 elements from k1
**/
// using TileShape = Shape<Shape<_32,_4>, Shape<_64,_2>, Shape<_8,_8>>;
using EpilogueThreadOp = cutlass::epilogue::thread::LinearCombination<
ElementD, 1, ElementAccumulator, ElementEpilogue, cutlass::epilogue::thread::ScaleType::Default,
cutlass::FloatRoundStyle::round_to_nearest, ElementC>;
// No changes are required to the default epilogue
using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
cutlass::epilogue::collective::DefaultEpilogue<
StrideC,
StrideD,
EpilogueThreadOp,
cutlass::gemm::EpilogueDefault>>;
// CollectiveMma for GETTs can be built using the CollectiveBuilders
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
ElementA, StrideA, 128 / cutlass::sizeof_bits<ElementA>::value,
ElementB, StrideB, 128 / cutlass::sizeof_bits<ElementB>::value,
ElementAccumulator,
TileShape, Shape<_1,_2,_1>,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;
// The GETT kernel is a composition of a collective mainloop and epilogue, just like any 3.x GEMM
using GettKernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShapeMNKL,
CollectiveMainloop,
CollectiveEpilogue>;
using GettOperator = cutlass::gemm::device::GemmUniversalAdapter<GettKernel>;
typename GettOperator::Arguments args {
cutlass::gemm::GemmUniversalMode::kBatched,
problem_shape_mnkl,
{ ptr_A, stride_a_mkl, ptr_B, stride_b_nkl },
{ {alpha, beta}, ptr_C, stride_c_mnl, ptr_D, stride_d_mnl }
};
#if CUTLASS_DEBUG_TRACE_LEVEL > 0
print("Problem shape:");
print("\tM: "); print(cute::get<0>(problem_shape_mnkl)); print("\n");
print("\tN: "); print(cute::get<1>(problem_shape_mnkl)); print("\n");
print("\tK: "); print(cute::get<2>(problem_shape_mnkl)); print("\n");
print("\tL: "); print(cute::get<3>(problem_shape_mnkl)); print("\n");
print("TileSape:"); print(TileShape{}); print("\n");
#endif
GettOperator op;
return op(args, stream);
}
} // namespace example

View File

@ -0,0 +1,687 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Example of a Hopper gather+GEMM+scatter kernel fusion.
This example fuses gather before GEMM and scatter after GEMM into the same
GEMM kernel. The gather and scatter operations are controlled by an index vector
to select rows or columns from the A, B, C or D matrices.
Gather/scatter operations are always performed along a strided dimension
in order to preserve vectorized loads/stores. Thus the index vector is
applied to rows of row-major matrices and columns of column-major matrices.
Note that the index vector must contain integers in range [0,X) where
X is one of (M,N,K), depending on selected gather dimension. The problem
shape given to the GEMM kernel must consist of matrix sizes AFTER gather
and BEFORE scatter operations are applied.
*/
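
Before the implementation, here is a small host-side sketch of the gather semantics described above; it is illustrative only (gather_rows is a hypothetical helper, not part of CUTLASS) and simply selects whole contiguous rows of a row-major matrix by index.

#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative reference of a row gather on a row-major (M x K) matrix A.
// Every index must lie in [0, M); the output has indices.size() rows of length K.
template <class T>
std::vector<T> gather_rows(std::vector<T> const& A, int64_t M, int64_t K,
                           std::vector<int> const& indices) {
  std::vector<T> out(indices.size() * size_t(K));
  for (size_t i = 0; i < indices.size(); ++i) {
    assert(indices[i] >= 0 && int64_t(indices[i]) < M);
    // Each selected row is copied contiguously, preserving vectorizable access.
    for (int64_t k = 0; k < K; ++k) {
      out[i * size_t(K) + size_t(k)] = A[size_t(indices[i]) * size_t(K) + size_t(k)];
    }
  }
  return out;
}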
#include <stdlib.h>
#include <stdio.h>
#include <time.h>
#include <math.h>
#include <assert.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <iostream>
#include <random>
#include <numeric>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/tensor_view_io.h"
#include "helper.h"
#include "gather_gemm.hpp"
#include "gather_kernel.cuh"
#include "scatter_epilogue.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
using namespace cute;
namespace example {
// Command line options parsing
struct Options {
bool help = false;
cutlass::gemm::BatchedGemmCoord problem_size = {2048, 2048, 2048, 1};
int index_size = 1024;
int mode = 1; // N-mode gather/scatter by default
float alpha = 1.0f;
float beta = 1.0f;
bool reference_check = true;
int iterations = 20;
bool valid() const {
return problem_size.m() > 0
&& problem_size.n() > 0
&& problem_size.k() > 0
&& problem_size.batch() > 0
&& 0 <= mode && mode < 3
&& index_size <= problem_size.at(mode)
&& iterations > 0;
}
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
cmd.get_cmd_line_argument("m", problem_size.m());
cmd.get_cmd_line_argument("n", problem_size.n());
cmd.get_cmd_line_argument("k", problem_size.k());
cmd.get_cmd_line_argument("batch_size", problem_size.batch());
cmd.get_cmd_line_argument("index_size", index_size);
char const modes[] = {'m', 'n', 'k'};
char mode_input = modes[mode];
cmd.get_cmd_line_argument("mode", mode_input);
mode = int(std::distance(std::begin(modes), std::find(std::begin(modes), std::end(modes), mode_input)));
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("check", reference_check, true);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out <<
"52_hopper_gather_scatter_fusion example\n"
"\n"
" This example uses the CUTLASS Library to fuse gather/scatter of input/output tensors with GEMM.\n"
" It validates and benchmarks the fused kernel against an unfused implementation that executes\n"
" gather+GEMM+scatter in sequence and writes intermediate (gathered) tensors to memory.\n"
" For the unfused implementation two GEMM kernels are considered: default one that uses the same\n"
" schedule and instruction set as the fused one, and an optimized one that utilizes advanced\n"
" features (such as TMA units) that cannot be used by the fused kernel due to hardware constraints."
"\n"
"Options:\n"
" --help If specified, displays this usage statement.\n"
" --m=<int> GEMM M dimension\n"
" --n=<int> GEMM N dimension\n"
" --k=<int> GEMM K dimension\n"
" --batch_size=<int> GEMM batch size\n"
" --index_size=<int> Size of N dimension gather/scatter index\n"
" --mode=<m,n,k> Gather mode (M, N, or K)\n"
" --alpha=<float> GEMM alpha parameter\n"
" --beta=<float> GEMM beta parameter\n"
" --iterations=<int> Number of profiling iterations to perform.\n"
"\n"
"Examples:\n"
"\n"
"$ ./examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion --m=1024 --n=2048 --k=1024 --mode=n --index_size=1024\n";
return out;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
template<class ElementA, class LayoutA, class GatherA,
class ElementB, class LayoutB, class GatherB,
class ElementC, class LayoutC, class GatherC,
class ElementD, class LayoutD, class ScatterD,
class ElementAccumulator, class ElementComputeEpilogue>
struct ExampleRunner
{
// Useful aliases
// Alias for the epilogue type that supports gather/scatter
using Epilogue = cutlass::epilogue::collective::EpilogueGatherScatter<
cutlass::gemm::TagToStrideC_t<LayoutC>,
cutlass::gemm::TagToStrideC_t<LayoutD>,
cutlass::epilogue::thread::LinearCombination<
ElementD, 1,
ElementAccumulator, ElementComputeEpilogue,
cutlass::epilogue::thread::ScaleType::Default,
cutlass::FloatRoundStyle::round_to_nearest, ElementC
>,
cutlass::gemm::EpilogueDefault,
GatherC,
ScatterD
>;
// Alias for the mainloop type
using Mainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
ElementA, LayoutA, 128 / cutlass::sizeof_bits<ElementA>::value,
ElementB, LayoutB, 128 / cutlass::sizeof_bits<ElementB>::value,
ElementAccumulator,
Shape<_128,_128,_64>,
Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCount<5>,
cutlass::gemm::KernelMultistage
>::CollectiveOp;
using ProblemShape = Shape<int,int,int,int>;
using Kernel = cutlass::gemm::kernel::GemmGather<
ProblemShape,
Mainloop,
Epilogue,
GatherA,
GatherB
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<Kernel>;
using StrideA = typename Kernel::StrideA;
using StrideB = typename Kernel::StrideB;
using StrideC = typename Kernel::StrideC;
using StrideD = typename Kernel::StrideD;
static constexpr bool DoGatherA = not cutlass::platform::is_same<GatherA, NoGather>::value;
static constexpr bool DoGatherB = not cutlass::platform::is_same<GatherB, NoGather>::value;
static constexpr bool DoGatherC = not cutlass::platform::is_same<GatherC, NoGather>::value;
static constexpr bool DoScatterD = not cutlass::platform::is_same<ScatterD, NoGather>::value;
static constexpr bool GatherAonM = DoGatherA && cutlass::platform::is_same<LayoutA,cutlass::layout::RowMajor>::value;
static constexpr bool GatherAonK = DoGatherA && cutlass::platform::is_same<LayoutA,cutlass::layout::ColumnMajor>::value;
static constexpr bool GatherBonN = DoGatherB && cutlass::platform::is_same<LayoutB,cutlass::layout::ColumnMajor>::value;
static constexpr bool GatherBonK = DoGatherB && cutlass::platform::is_same<LayoutB,cutlass::layout::RowMajor>::value;
static constexpr bool GatherConM = DoGatherC && cutlass::platform::is_same<LayoutC,cutlass::layout::RowMajor>::value;
static constexpr bool GatherConN = DoGatherC && cutlass::platform::is_same<LayoutC,cutlass::layout::ColumnMajor>::value;
static constexpr bool ScatterDonM = DoScatterD && cutlass::platform::is_same<LayoutD,cutlass::layout::RowMajor>::value;
static constexpr bool ScatterDonN = DoScatterD && cutlass::platform::is_same<LayoutD,cutlass::layout::ColumnMajor>::value;
static constexpr bool GatherModeM = GatherAonM || GatherConM || ScatterDonM;
static constexpr bool GatherModeN = GatherBonN || GatherConN || ScatterDonN;
static constexpr bool GatherModeK = GatherAonK || GatherBonK;
static_assert( GatherModeM && !GatherModeN && !GatherModeK ||
!GatherModeM && GatherModeN && !GatherModeK ||
!GatherModeM && !GatherModeN && GatherModeK,
"Only one gather mode (M, N or K) is supported by example runner");
// Construct a reference (non-gather) GEMM kernel type
using MainloopRef = Mainloop;
using EpilogueRef = typename cutlass::epilogue::collective::DefaultEpilogue<
StrideC, StrideD,
typename Epilogue::ThreadEpilogueOp,
typename Epilogue::EpilogueSchedule
>;
using KernelRef = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
MainloopRef,
EpilogueRef
>;
using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<KernelRef>;
// Construct an optimized reference GEMM kernel type (using TMA)
using EpilogueOpt = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_128,_128,_64>,
Shape<_2,_2,_1>,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementComputeEpilogue,
ElementC, LayoutC, 128 / cutlass::sizeof_bits<ElementC>::value,
ElementD, LayoutD, 128 / cutlass::sizeof_bits<ElementD>::value,
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
using MainloopOpt = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
ElementA, LayoutA, 128 / cutlass::sizeof_bits<ElementA>::value,
ElementB, LayoutB, 128 / cutlass::sizeof_bits<ElementB>::value,
ElementAccumulator,
Shape<_128,_128,_64>,
Shape<_2,_2,_1>,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename EpilogueOpt::SharedStorage)>,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;
using KernelOpt = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
MainloopOpt,
EpilogueOpt
>;
using GemmOpt = cutlass::gemm::device::GemmUniversalAdapter<KernelOpt>;
// Data members
cutlass::gemm::BatchedGemmCoord problem_size_orig;
cutlass::gemm::BatchedGemmCoord problem_size;
ProblemShape problem_shape_orig;
ProblemShape problem_shape;
cutlass::KernelHardwareInfo hw_info;
ElementComputeEpilogue alpha;
ElementComputeEpilogue beta;
StrideA stride_A_orig;
StrideB stride_B_orig;
StrideC stride_C_orig;
StrideD stride_D_orig;
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
cutlass::device_memory::allocation<ElementA> tensor_a;
cutlass::device_memory::allocation<ElementB> tensor_b;
cutlass::device_memory::allocation<ElementC> tensor_c;
cutlass::device_memory::allocation<ElementD> tensor_d;
cutlass::device_memory::allocation<int> gather_indices;
cutlass::device_memory::allocation<ElementA> tensor_a_gathered;
cutlass::device_memory::allocation<ElementB> tensor_b_gathered;
cutlass::device_memory::allocation<ElementC> tensor_c_gathered;
cutlass::device_memory::allocation<ElementD> tensor_d_gathered;
cutlass::device_memory::allocation<ElementD> tensor_d_reference;
cutlass::gemm::GemmUniversalMode gemm_mode;
Gemm gemm;
typename Gemm::Arguments arguments;
cutlass::device_memory::allocation<uint8_t> workspace;
GemmRef gemm_ref;
typename GemmRef::Arguments arguments_ref;
cutlass::device_memory::allocation<uint8_t> workspace_ref;
GemmOpt gemm_opt;
typename GemmOpt::Arguments arguments_opt;
cutlass::device_memory::allocation<uint8_t> workspace_opt;
ExampleRunner(Options const &options, cutlass::KernelHardwareInfo const &hw_info)
: problem_size_orig(options.problem_size),
problem_size(GatherModeM ? options.index_size : problem_size_orig.m(),
GatherModeN ? options.index_size : problem_size_orig.n(),
GatherModeK ? options.index_size : problem_size_orig.k(),
problem_size_orig.batch()),
problem_shape_orig(problem_size_orig.m(), problem_size_orig.n(), problem_size_orig.k(), problem_size_orig.batch()),
problem_shape(problem_size.m(), problem_size.n(), problem_size.k(), problem_size.batch()),
hw_info(hw_info),
alpha(options.alpha),
beta(options.beta),
stride_A_orig(cutlass::make_cute_packed_stride(
StrideA{}, make_shape(problem_size_orig.m(), problem_size_orig.k(), problem_size_orig.batch()))),
stride_B_orig(cutlass::make_cute_packed_stride(
StrideB{}, make_shape(problem_size_orig.n(), problem_size_orig.k(), problem_size_orig.batch()))),
stride_C_orig(cutlass::make_cute_packed_stride(
StrideC{}, make_shape(problem_size_orig.m(), problem_size_orig.n(), problem_size_orig.batch()))),
stride_D_orig(cutlass::make_cute_packed_stride(
StrideD{}, make_shape(problem_size_orig.m(), problem_size_orig.n(), problem_size_orig.batch()))),
stride_A(cutlass::make_cute_packed_stride(
StrideA{}, make_shape(problem_size.m(), problem_size.k(), problem_size.batch()))),
stride_B(cutlass::make_cute_packed_stride(
StrideB{}, make_shape(problem_size.n(), problem_size.k(), problem_size.batch()))),
stride_C(cutlass::make_cute_packed_stride(
StrideC{}, make_shape(problem_size.m(), problem_size.n(), problem_size.batch()))),
stride_D(cutlass::make_cute_packed_stride(
StrideD{}, make_shape(problem_size.m(), problem_size.n(), problem_size.batch()))),
tensor_a(problem_size_orig.m() * problem_size_orig.k() * problem_size_orig.batch()),
tensor_b(problem_size_orig.k() * problem_size_orig.n() * problem_size_orig.batch()),
tensor_c(problem_size_orig.m() * problem_size_orig.n() * problem_size_orig.batch()),
tensor_d(problem_size_orig.m() * problem_size_orig.n() * problem_size_orig.batch()),
gather_indices(options.index_size),
tensor_a_gathered(problem_size.m() * problem_size.k() * problem_size_orig.batch()),
tensor_b_gathered(problem_size.k() * problem_size.n() * problem_size_orig.batch()),
tensor_c_gathered(problem_size.m() * problem_size.n() * problem_size_orig.batch()),
tensor_d_gathered(problem_size.m() * problem_size.n() * problem_size_orig.batch()),
tensor_d_reference(problem_size_orig.m() * problem_size_orig.n() * problem_size_orig.batch()),
gemm_mode(problem_size.batch() > 1 ? cutlass::gemm::GemmUniversalMode::kBatched : cutlass::gemm::GemmUniversalMode::kGemm),
gemm(),
// When constructing arguments for gather/scatter gemm, we must pass stride arguments
// made for the original (non-gathered) problem size, because they are used to access
// tensors of the original shape. However, we still use the reduced (gathered) problem
// shape, since it corresponds to the logical indexing of the reduced-size GEMM.
arguments{
gemm_mode,
problem_shape,
{
tensor_a.get(),
stride_A_orig,
tensor_b.get(),
stride_B_orig
},
{
{ alpha, beta },
tensor_c.get(), stride_C_orig,
tensor_d.get(), stride_D_orig,
typename Epilogue::GatherC {gather_indices.get()},
typename Epilogue::ScatterD{gather_indices.get()}
},
hw_info,
typename Kernel::GatherA{gather_indices.get()},
typename Kernel::GatherB{gather_indices.get()}
},
workspace(Gemm::get_workspace_size(arguments)),
gemm_ref(),
arguments_ref{
gemm_mode,
problem_shape,
{
DoGatherA ? tensor_a_gathered.get() : tensor_a.get(),
stride_A,
DoGatherB ? tensor_b_gathered.get() : tensor_b.get(),
stride_B
},
{
{ alpha, beta },
DoGatherC ? tensor_c_gathered.get() : tensor_c.get(),
stride_C,
DoScatterD ? tensor_d_gathered.get() : tensor_d_reference.get(),
stride_D
},
hw_info
},
workspace_ref(GemmRef::get_workspace_size(arguments_ref)),
gemm_opt(),
arguments_opt{
gemm_mode,
problem_shape,
{
DoGatherA ? tensor_a_gathered.get() : tensor_a.get(),
stride_A,
DoGatherB ? tensor_b_gathered.get() : tensor_b.get(),
stride_B
},
{
{ alpha, beta },
DoGatherC ? tensor_c_gathered.get() : tensor_c.get(),
stride_C,
DoScatterD ? tensor_d_gathered.get() : tensor_d_reference.get(),
stride_D
},
hw_info
},
workspace_opt(GemmOpt::get_workspace_size(arguments_opt))
{
// Fill input and output matrices on host using CUTLASS helper functions
cutlass::reference::device::BlockFillRandomUniform(tensor_a.get(), tensor_a.size(), 1, ElementA(7), ElementA(-8), 0);
cutlass::reference::device::BlockFillRandomUniform(tensor_b.get(), tensor_b.size(), 1, ElementB(7), ElementB(-8), 0);
cutlass::reference::device::BlockFillRandomUniform(tensor_c.get(), tensor_c.size(), 1, ElementC(7), ElementC(-8), 0);
cutlass::reference::device::BlockFillSequential(tensor_d.get(), tensor_d.size(), ElementD(0), ElementD(0));
// <- Fill gather_indices with unique random integers in range [0,n)
int index_range = GatherModeM ? problem_size_orig.m() : (GatherModeN ? problem_size_orig.n() : problem_size_orig.k());
std::vector<int> indices(index_range);
std::iota(indices.begin(), indices.end(), 0);
{ // std::random_shuffle was deprecated in C++14 and removed in C++17
std::random_device make_seed;
std::mt19937 source_of_randomness(make_seed());
std::shuffle(indices.begin(), indices.end(), source_of_randomness);
}
gather_indices.copy_from_host(indices.data());
auto const gemm_init = [](auto & gemm, auto const & arguments, auto & workspace)
{
cutlass::Status status = gemm.can_implement(arguments);
CUTLASS_CHECK(status);
status = gemm.initialize(arguments, workspace.get());
CUTLASS_CHECK(status);
};
gemm_init(gemm, arguments, workspace );
gemm_init(gemm_ref, arguments_ref, workspace_ref);
gemm_init(gemm_opt, arguments_opt, workspace_opt);
}
void debug_output(std::ostream & os)
{
auto print_tensor = [](std::ostream &os, char const * name, auto const & data, auto shape, auto stride)
{
std::vector<remove_cvref_t<decltype(*data.get())>> h_data(data.size());
data.copy_to_host(h_data.data());
Tensor t = make_tensor(h_data.data(), shape, stride);
os << "\n" << name << ": " << std::setw(4) << t << std::endl;
};
{
auto [M,N,K,L] = problem_shape_orig;
print_tensor(os, "A", tensor_a, make_shape(M,K,L), stride_A_orig);
print_tensor(os, "B", tensor_b, make_shape(N,K,L), stride_B_orig);
print_tensor(os, "C", tensor_c, make_shape(M,N,L), stride_C_orig);
print_tensor(os, "D", tensor_d, make_shape(M,N,L), stride_D_orig);
print_tensor(os, "D reference", tensor_d_reference, make_shape(M,N,L), stride_D_orig);
print_tensor(os, "indices", gather_indices, make_shape(gather_indices.size()), make_stride(_1{}));
}
}
template<class Gemm2>
static void run_gemm(Gemm2 &gemm)
{
cutlass::Status status = gemm.run();
CUTLASS_CHECK(status);
}
template<class Gemm2>
void run_reference(Gemm2 &gemm)
{
// Convenience wrapper around calls to separate gather/scatter kernels
auto run_gather = [this](auto call, auto const & input, auto & output, auto gather_func, auto batch_size, auto stride)
{
[[maybe_unused]] auto idx = find_if(stride, [](auto x){ return not is_constant<1, decltype(x)>{}; });
constexpr int I = decltype(idx)::value;
call(input.get(),
output.get(),
gather_func,
batch_size,
static_cast<int>(input.size() / batch_size),
static_cast<int>(output.size() / batch_size),
static_cast<int>(get<I>(stride)),
hw_info);
};
// Forward calls via lambda to avoid specifying template arguments
auto gather_call = [](auto&&... args){ gather(static_cast<decltype(args)&&>(args)...); };
// MSVC does not count a use inside a discarded "if constexpr" branch, hence [[maybe_unused]].
[[maybe_unused]] auto scatter_call = [](auto&&... args){ scatter(static_cast<decltype(args)&&>(args)...); };
if constexpr (DoGatherA) {
run_gather(gather_call, tensor_a, tensor_a_gathered, arguments.gather_A, problem_size.batch(), stride_A);
}
if constexpr (DoGatherB) {
run_gather(gather_call, tensor_b, tensor_b_gathered, arguments.gather_B, problem_size.batch(), stride_B);
}
if constexpr (DoGatherC) {
if (beta != ElementComputeEpilogue(0)) {
run_gather(gather_call, tensor_c, tensor_c_gathered, arguments.epilogue.gather_C, problem_size.batch(), stride_C);
}
}
run_gemm(gemm);
if constexpr (DoScatterD) {
run_gather(scatter_call, tensor_d_gathered, tensor_d_reference, arguments.epilogue.scatter_D, problem_size.batch(), stride_D);
}
}
bool verify()
{
run_gemm(gemm);
run_reference(gemm_ref);
cudaDeviceSynchronize();
return cutlass::reference::device::BlockCompareEqual(tensor_d.get(), tensor_d_reference.get(), tensor_d.size());
}
bool run(Options const &options)
{
if (options.reference_check) {
if (!verify()) {
std::cout << "Failed validation" << std::endl;
#if 1
debug_output(std::cout);
#endif
return false;
}
else {
std::cout << "Passed validation" << std::endl;
}
}
//
// Run profiling loop
//
auto const benchmark = [&](auto name, auto func)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
func();
}
timer.stop();
double runtime = timer.elapsed_millis() / double(options.iterations);
double gflops = 2 * double(problem_size.product()) / 1e6 / runtime; // Two flops per multiply-add
std::cout << name << ":\n";
std::cout << " Runtime: " << runtime << " ms\n";
std::cout << " GFLOPs: " << gflops << "\n";
};
benchmark("Fused", [&](){ run_gemm(gemm); });
benchmark("Unfused default", [&](){ run_reference(gemm_ref); });
benchmark("Unfused optimized", [&](){ run_reference(gemm_opt); });
return true;
}
};
} // namespace example
int main(int argc, const char ** argv) {
bool notSupported = false;
// CUDA 12 minimum required
if (__CUDACC_VER_MAJOR__ < 12) {
std::cerr << "This example requires CUDA Toolkit version 12 or later.\n";
notSupported = true;
}
cudaDeviceProp props;
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
if (props.major < 9) {
std::cerr << "This example requires a device with compute capability 90 or higher.\n";
notSupported = true;
}
if (notSupported) {
return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems
}
example::Options options;
options.parse(argc, argv);
if (options.help) {
options.print_usage(std::cout) << "\n";
return EXIT_SUCCESS;
}
if (!options.valid()) {
std::cerr << "Invalid arguments." << "\n";
return EXIT_FAILURE;
}
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
bool result = true;
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
switch (options.mode) {
using namespace example;
case 0: {
std::cout << "Gather A,C + scatter D on M mode:" << std::endl;
using Runner = ExampleRunner<
cutlass::half_t, cutlass::layout::RowMajor, IndexedGather<int>, // A
cutlass::half_t, cutlass::layout::ColumnMajor, NoGather, // B
cutlass::half_t, cutlass::layout::RowMajor, IndexedGather<int>, // C
cutlass::half_t, cutlass::layout::RowMajor, IndexedGather<int>, // D
float, float>;
result &= Runner(options, hw_info).run(options);
break;
}
case 1: {
std::cout << "Gather B,C + scatter D on N mode:" << std::endl;
using Runner = ExampleRunner<
cutlass::half_t, cutlass::layout::RowMajor, NoGather, // A
cutlass::half_t, cutlass::layout::ColumnMajor, IndexedGather<int>, // B
cutlass::half_t, cutlass::layout::ColumnMajor, IndexedGather<int>, // C
cutlass::half_t, cutlass::layout::ColumnMajor, IndexedGather<int>, // D
float, float>;
result &= Runner(options, hw_info).run(options);
break;
}
case 2: {
std::cout << "Gather A,B on K mode:" << std::endl;
using Runner = ExampleRunner<
cutlass::half_t, cutlass::layout::ColumnMajor, IndexedGather<int>, // A
cutlass::half_t, cutlass::layout::RowMajor, IndexedGather<int>, // B
cutlass::half_t, cutlass::layout::RowMajor, NoGather, // C
cutlass::half_t, cutlass::layout::RowMajor, NoGather, // D
float, float>;
result &= Runner(options, hw_info).run(options);
break;
}
}
#endif
return result ? EXIT_SUCCESS : EXIT_FAILURE;
}

View File

@ -0,0 +1,32 @@
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
52_hopper_gather_scatter_fusion
52_hopper_gather_scatter_fusion.cu
)

View File

@ -0,0 +1,266 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cute/tensor.hpp"
#include "gather_tensor.hpp"
namespace cutlass::gemm::kernel {
///////////////////////////////////////////////////////////////////////////////
template <
class ProblemShape_,
class CollectiveMainloop_,
class CollectiveEpilogue_,
class GatherA_,
class GatherB_,
class TileScheduler_ = void
>
class GemmGather
{
public:
//
// Type Aliases
//
using ProblemShape = ProblemShape_;
using TileSchedulerTag = TileScheduler_;
using TileScheduler = TileScheduler_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
using TileShape = typename CollectiveMainloop::TileShape;
using TiledMma = typename CollectiveMainloop::TiledMma;
using ArchTag = typename CollectiveMainloop::ArchTag;
using ElementA = typename CollectiveMainloop::ElementA;
using StrideA = typename CollectiveMainloop::StrideA;
using ElementB = typename CollectiveMainloop::ElementB;
using StrideB = typename CollectiveMainloop::StrideB;
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
using MainloopArguments = typename CollectiveMainloop::Arguments;
using MainloopParams = typename CollectiveMainloop::Params;
// Epilogue derived types
using CollectiveEpilogue = CollectiveEpilogue_;
using ElementC = typename CollectiveEpilogue::ElementC;
using StrideC = typename CollectiveEpilogue::StrideC;
using ElementD = typename CollectiveEpilogue::ElementD;
using StrideD = typename CollectiveEpilogue::StrideD;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
using EpilogueParams = typename CollectiveEpilogue::Params;
static_assert(std::is_same_v<ElementAccumulator, typename CollectiveEpilogue::ElementAccumulator>,
"Mainloop and epilogue do not agree on accumulator value type.");
using GatherA = GatherA_;
using GatherB = GatherB_;
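// The mainloop and epilogue share one dynamic shared memory allocation, so the
// kernel reserves the larger of their two SharedStorage sizes.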
static constexpr int SharedStorageSize = static_cast<int>(cute::max(
sizeof(typename CollectiveMainloop::SharedStorage),
sizeof(typename CollectiveEpilogue::SharedStorage)));
static constexpr uint32_t MaxThreadsPerBlock = cute::size(TiledMma{});
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
// Device side arguments
struct Arguments {
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopArguments mainloop{};
EpilogueArguments epilogue{};
KernelHardwareInfo hw_info{};
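// Gather functors for A and B: each remaps a logical coordinate of its operand
// to the source index that is actually loaded (see make_gather_tensor below).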
GatherA gather_A{};
GatherB gather_B{};
};
// Kernel entry point API
struct Params {
GemmUniversalMode mode;
ProblemShape problem_shape;
MainloopParams mainloop;
EpilogueParams epilogue;
GatherA gather_A{};
GatherB gather_B{};
};
//
// Methods
//
// Convert to underlying arguments.
static
Params
to_underlying_arguments(Arguments const& args, void* workspace) {
(void) workspace;
return {
args.mode,
args.problem_shape,
CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace),
args.gather_A,
args.gather_B
};
}
static
Status
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
return Status::kSuccess;
}
static
bool
can_implement(Arguments const& args) {
return args.mode == GemmUniversalMode::kGemm or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
}
static
int
get_workspace_size(Arguments const& args) {
return 0;
}
static constexpr
dim3
get_grid_shape(Params const& params) {
int batch_count = 1;
if constexpr (rank(ProblemShape{}) == 4) {
batch_count = cute::size<3>(params.problem_shape);
}
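// Example with assumed sizes: M = N = 4096 and a 128x128 CTA tile give
// ceil_div(4096, 128) = 32 blocks in x and y, i.e. a 32 x 32 x batch_count grid.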
return dim3(
cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))),
cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))),
batch_count
);
}
static constexpr
dim3
get_block_shape() {
return dim3(MaxThreadsPerBlock, 1, 1);
}
CUTLASS_DEVICE
void
operator()(Params const& params, char* smem_buf) {
using namespace cute;
using X = Underscore;
// Preconditions
CUTE_STATIC_ASSERT(is_static<TileShape>::value);
// Separate out problem shape for convenience
// Optionally append 1s until problem shape is rank-4, in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
auto M = get<0>(problem_shape_MNKL);
auto N = get<1>(problem_shape_MNKL);
auto K = get<2>(problem_shape_MNKL);
auto L = get<3>(problem_shape_MNKL);
// Preconditions
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
// Get the appropriate blocks for this thread block -- potential for thread block locality
int thread_idx = int(threadIdx.x);
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
auto [m_coord, n_coord, l_coord] = blockIdx;
auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord); // (m,n,k,l)
// Represent the full tensors
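// make_gather_tensor (from gather_tensor.hpp) folds the gather functor into the tensor's
// strides, so the indirected mode of A and B is resolved at load time and no gathered
// copy of either operand is materialized in global memory.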
Tensor mA_mkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA, params.gather_A); //(m,k,l)
Tensor mB_nkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB, params.gather_B); //(n,k,l)
// Get batch slice
Tensor mA_mk = mA_mkl(_,_,l_coord); // (m,k)
Tensor mB_nk = mB_nkl(_,_,l_coord); // (n,k)
// Slice to get the tiles this thread block is responsible for
Tensor gA = local_tile(mA_mk, blk_shape, take<0,3>(blk_coord_mnkl), Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
Tensor gB = local_tile(mB_nk, blk_shape, take<0,3>(blk_coord_mnkl), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
// Compute tile residues for predication
auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord_mnkl); // M - BLK_M * m_coord
auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord_mnkl); // N - BLK_N * n_coord
auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max
auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue);
// Allocate the tiled_mma and the accumulators for the (M,N) blk_shape
TiledMma tiled_mma;
Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N)
clear(accumulators);
auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA));
int k_tile_count = size<2>(gA);
// Perform the collective scoped MMA
CollectiveMainloop collective_mma;
collective_mma(
accumulators,
gA,
gB,
accumulators,
k_tile_iter, k_tile_count,
residue_mnk,
thread_idx,
smem_buf
);
// Epilogue and write to gD
CollectiveEpilogue epilogue{params.epilogue};
epilogue(
problem_shape_MNKL,
blk_shape,
blk_coord_mnkl,
accumulators,
tiled_mma,
residue_mnk,
thread_idx,
smem_buf
);
}
};
///////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::kernel
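For readers skimming the diff: the kernel above never materializes the gathered operand; the gather functor is folded into the operand tensors' strides, so indirect rows are fetched directly by the collective mainloop. As a point of comparison only, here is a minimal, hypothetical CUDA sketch of the same fusion idea, independent of CUTLASS; every name in it (gather_gemm_naive, idx_A, and so on) is made up for illustration.

// Naive gather-fused GEMM: D = gather_rows(A, idx_A) * B, with the gather done at load time.
#include <cuda_runtime.h>

__global__ void gather_gemm_naive(
    float const* A,      // (M_src, K) row-major source matrix
    float const* B,      // (K, N)     row-major
    float*       D,      // (M, N)     row-major output
    int const*   idx_A,  // M gather indices selecting rows of A
    int M, int N, int K)
{
  int m = blockIdx.y * blockDim.y + threadIdx.y;
  int n = blockIdx.x * blockDim.x + threadIdx.x;
  if (m >= M || n >= N) return;

  // The indirection happens at load time: row idx_A[m] of A is read directly,
  // so no gathered copy of A is created beforehand.
  float const* a_row = A + size_t(idx_A[m]) * K;

  float acc = 0.f;
  for (int k = 0; k < K; ++k) {
    acc += a_row[k] * B[size_t(k) * N + n];
  }
  D[size_t(m) * N + n] = acc;
}

// Possible launch, assuming device buffers d_A, d_B, d_D, d_idx already exist:
//   dim3 block(16, 16);
//   dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y);
//   gather_gemm_naive<<<grid, block>>>(d_A, d_B, d_D, d_idx, M, N, K);

The CUTLASS version replaces this scalar loop with the Hopper collective mainloop and epilogue shown above; only the address computation changes, which is exactly what make_gather_tensor expresses.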

Some files were not shown because too many files have changed in this diff.