Compare commits
137 Commits
| SHA1 | Author | Date |
|---|---|---|
| 319a389f42 | |||
| 71def2f084 | |||
| 70f3ba57f5 | |||
| dd77fadc70 | |||
| be4578d517 | |||
| d7b499deff | |||
| 310ed81ac3 | |||
| 4c0d6e1eb4 | |||
| 167ac54c65 | |||
| 12f4108ac2 | |||
| dd571f0edb | |||
| 6d0d265047 | |||
| f11fa975a5 | |||
| 0e71d9b450 | |||
| eb0d4c9213 | |||
| bc45e2c023 | |||
| 095cbba57c | |||
| 8f1fe7a132 | |||
| 3ab1eacf09 | |||
| cd39c75e25 | |||
| b2e1e97cb1 | |||
| 96a11a1ef3 | |||
| e96f00586c | |||
| 3cfa5db2a2 | |||
| 1db6971a8d | |||
| b954127297 | |||
| d0d941efc7 | |||
| 8a951b2940 | |||
| 1e4703cbab | |||
| c3353add63 | |||
| ac8825b941 | |||
| 8fd94806e5 | |||
| d7c9cbf0b9 | |||
| c2ee13a0fe | |||
| f78994bb40 | |||
| dceabd4c5a | |||
| 86fa1dc30b | |||
| 288af365db | |||
| 0dc3ba60b3 | |||
| ec4f7e5194 | |||
| 4e666e1dfd | |||
| 3799e12f25 | |||
| fc3bc85db8 | |||
| 49c0a58d50 | |||
| 5fe09c2d67 | |||
| 6b69c79ac3 | |||
| 62e438f450 | |||
| 808c25337a | |||
| 6fc5008803 | |||
| a3bcc6981d | |||
| 3b28642801 | |||
| 538592dea4 | |||
| 2e07c4cc2f | |||
| 9ac255863f | |||
| 59e2aa505a | |||
| 4e8af93da1 | |||
| 6c2f8f2fb8 | |||
| 598e35401c | |||
| a01feb93d9 | |||
| d36f331b44 | |||
| 69abafb85a | |||
| 68a078fbbf | |||
| 10709dbb64 | |||
| 1227351079 | |||
| a77c658439 | |||
| 4516b833ce | |||
| 64dd1e1915 | |||
| 1ac4559d12 | |||
| e5d51840e8 | |||
| 6c29fe20ba | |||
| e3c56b0d6b | |||
| 4647c57243 | |||
| 856d4db3fb | |||
| 6a1064093f | |||
| c5f1ef4dff | |||
| 47ebfccbec | |||
| ad9486684f | |||
| 1d8372a8e2 | |||
| 9cb7d63424 | |||
| da2f110906 | |||
| b68113f5be | |||
| a68d7cd6f1 | |||
| 38e8b29f56 | |||
| ee7349c94f | |||
| 8cdd4293d4 | |||
| f58b843951 | |||
| 5fc142296f | |||
| 233d69aa6d | |||
| 9840d25269 | |||
| b878c96421 | |||
| 8f8a80cad5 | |||
| a8f6f8eb07 | |||
| f4b0a33633 | |||
| 7c783adf53 | |||
| 4000df9567 | |||
| bb35a3ba6f | |||
| 7ec3a87f22 | |||
| 0b74c8f473 | |||
| 83036ed646 | |||
| b7e43f5eb9 | |||
| 5c62d892fa | |||
| 41a31b404b | |||
| 7320aee17d | |||
| 2142a05d9d | |||
| c77a524459 | |||
| fac6680f31 | |||
| 08993707da | |||
| c805593ebe | |||
| 26556d7206 | |||
| 4839b6cb61 | |||
| d97214987a | |||
| b0bbc6d548 | |||
| 7074047a54 | |||
| 75a4737cfe | |||
| 8a3e4b8d02 | |||
| 6a6b4028bd | |||
| 92393b2676 | |||
| 50bf00e5f2 | |||
| 4cd004ead1 | |||
| 6c4539e372 | |||
| a3639ab1a0 | |||
| 169181f30f | |||
| 0f1056390d | |||
| 34a42e5620 | |||
| 8f09b82b12 | |||
| 200a5a5146 | |||
| 746b7b3247 | |||
| abdf16a4d9 | |||
| 0e13748649 | |||
| ccb697bac7 | |||
| e6bcdc60cf | |||
| 6615010cd0 | |||
| c2b80ad4e4 | |||
| 37a8f9e598 | |||
| c53f3339bb | |||
| 4dac7490e6 | |||
| fd7e058d0c |
2
.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
# PyCache files
__pycache__/

173
CHANGELOG.md
@@ -1,6 +1,135 @@
# NVIDIA CUTLASS Changelog

# CUTLASS 2.x
## [2.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.9.0) (2022-04-21)

* [First layer Convolution kernels](/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment
  * [Few channels](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h) specialization for reduced alignment capabilities
  * [Fixed channels](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h) further specialized when channel count perfectly matches the access vector size
  * [Unit tests](/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu)
  * [Python-based instance emitter](/tools/library/scripts/generator.py) in the CUTLASS Library and support in the Profiler
* [BLAS3](https://docs.nvidia.com/cuda/cublas/index.html#cublas-level-3-function-reference) operators accelerated by Tensor Cores
  * Supported types: f32, cf32, f64, cf64
  * [HERK](/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu) with [emitter](/tools/library/scripts/rank_k_operation.py)
  * [SYRK](/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu) with [emitter](/tools/library/scripts/rank_k_operation.py)
  * [SYMM](/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu) with [emitter](/tools/library/scripts/symm_operation.py)
  * [TRMM](/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu) with [emitter](/tools/library/scripts/trmm_operation.py)
  * [Unit tests](/test/unit/gemm/device/testbed_rank_k_universal.h)
* [CUTLASS Python](/example/40_cutlass_py) demonstrating JIT compilation of CUTLASS kernels and a Python-based runtime using [CUDA Python](https://developer.nvidia.com/cuda-python)
  * [Python-based runtime](/tools/library/scripts/rt.py) interoperable with existing emitters
* [GEMM + Softmax example](/examples/35_gemm_softmax)
* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads)
* Updates and bugfixes from the community (thanks!)


## [2.8.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.8.0) (2021-11-19)

* **TF32x3:** emulated single-precision using Tensor Cores
  * 45+ TFLOPs on NVIDIA A100
  * [GEMM SDK example](/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu) (real)
  * [COMPLEX GEMM SDK example](/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm.cu) (complex)
  * [Implicit GEMM Convolution SDK example](/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu)
* **Mainloop fusion for Convolution:** convolution with fused per-channel scale-bias-relu
  * [Conv Fprop SDK example](/examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu)
  * [Conv WGrad SDK example](/examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu)
  * [cutlass::conv::device::ImplicitGemmConvolutionFusion](/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h)
* **Grouped GEMM:** similar to batched GEMM with distinct problem size per group
  * [SDK example](/examples/24_gemm_grouped) with performance comparison with Batched Strided GEMM
  * [cutlass::gemm::device::GemmGrouped](/include/cutlass/gemm/device/gemm_grouped.h)
* [Implicit GEMM Convolution fusion](/examples/13_two_tensor_op_fusion/) supports staging the first convolution's output accumulator in shared memory on Turing. This allows more flexible warp tile sizes and less register pressure.
* Optimal performance using [**CUDA 11.5**](https://developer.nvidia.com/cuda-downloads)
* Updates from the community (thanks!)

* **Deprecation announcement:** CUTLASS plans to deprecate the following:
  * Maxwell and Pascal GPU architectures
  * Ubuntu 16.04
  * CUDA 10.2

## [2.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.7.0) (2021-09-24)
* Mainloop fusion for GEMM: [summation over A or B](/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu)
* [Strided DGRAD (optimized iterators)](/include/cutlass/conv/kernel/default_conv2d_dgrad.h)
* [Half-precision GELU_taylor activation functions](/include/cutlass/epilogue/thread/activation.h#L196)
  * Use these when accumulation and epilogue compute types are all `cutlass::half_t`
* Tuning and bug fixes to [fused GEMM + GEMM example](/examples/13_two_tensor_op_fusion/)
* Support for smaller than 128b aligned Convolutions: [see examples](test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu#L272)
* Caching of results to accelerate Convolution [unit tests](test/unit/conv/device/cache_testbed_output.h)
  * Can be enabled or disabled by running `cmake .. -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=OFF`
* Corrections and bug fixes reported by the CUTLASS community
  * Thank you for filing these issues!

## [2.6.1](https://github.com/NVIDIA/cutlass/releases/tag/v2.6.1) (2021-09-03)
|
||||
* Arbitrary padding and striding for CUTLASS Strided DGRAD Convolution operator (Analytic Iterators)
|
||||
* Tuning for GEMMs fused with partial reductions
|
||||
* Corrections and bug fixes reported by the CUTLASS community
|
||||
* Thank you for filing these issues!
|
||||
|
||||
## [2.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.6.0) (2021-07-22)
|
||||
* Optimal performance when compiled with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit)
|
||||
* Adopt the new L2 prefetch feature in [cp.async](/include/cutlass/arch/memory.h) and [global load](/include/cutlass/arch/memory_sm80.h)
|
||||
* Fused operators with GEMM and Convolution
|
||||
* [Fused broadcast in epilogue](test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu)
|
||||
* [Fused partial reduction in epilogue](/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu)
|
||||
* 64b tensor strides and leading dimensions support for GEMMs
|
||||
* Affine rank=2 matrix layouts
|
||||
* Row stride and column stride for matrices using [cutlass::layout::AffineRank2](/include/cutlass/layout/matrix.h)
|
||||
* Support [FP64 tensor core](/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu) and SIMT GEMM.
|
||||
* [Batched GEMV](/test/unit/gemm/device/gemv.cu) preview implementation
|
||||
* [New strided Dgrad](test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) implementation
|
||||
* Accelerates over previous implementation by cutting down redundant math by 4x
|
||||
* Support using new `Dy` and `w` analytic iterators and existing `cutlass::conv::device::ImplicitGemmConvolution` interface
|
||||
* Quaternion-valued GEMM and Convolution in single- and double-precision (targeting CUDA Cores)
|
||||
* Updates to [quaternion.h](/include/cutlass/quaternion.h) and [functional.h](/include/cutlass/functional.h)
|
||||
* SDK Example for [GEMM](/examples/21_quaternion_gemm/quaternion_gemm.cu) and [Convolution](/examples/22_quaternion_gemm/quaternion_conv.cu)
|
||||
* [Unit tests for GEMM](/test/unit/gemm/device/simt_qgemm_nn_sm50.cu) and [Convolution](/test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu)
|
||||
* Many improvements to the epilogue.
|
||||
* Provide an [option](/include/cutlass/epilogue/threadblock/epilogue.h) to not fully unroll the epilogue to reduce the code size and improve the performance when using complicated elementwise operations
|
||||
* Performance improvement for FP16 tensor core kernels
|
||||
* Bug fixes
|
||||
* Enhanced Clang support; the combination of Clang 13 and CUDA 11.4 can build and run kernels for the Pascal and Ampere architectures.
|
||||
* Updated minimum CUDA Toolkit requirement to 10.2
|
||||
* [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit) recommended
|
||||
* Corrections and bug fixes reported by the CUTLASS community
|
||||
* Thank you for filing these issues!
|
||||
|
||||
## [2.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.5.0) (2021-02-26)
|
||||
* Tensor reductions
|
||||
* _m_-to-_n_ reductions of tensors with affine layout
|
||||
* [Specializations](/test/unit/reduction/device/tensor_reduce_contiguous.cu) for reductions including contiguous dimension
|
||||
* [Specializations](/test/unit/reduction/device/tensor_reduce_strided.cu) for reductions excluding contiguous dimension
|
||||
* Custom reduction functors such as `cutlass::logical_and`
|
||||
* Large tensor support, up to 2^63 elements (however, each dimension is limited to an extent of 2^31)
|
||||
* Optimizations for 3-D convolution
|
||||
* [Optimized tile iterators](include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h) using precomputed delta table for 3-D convolution
|
||||
* Full coverage of [forward](test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu) and [backwards](test/unit/conv/device/conv3d_dgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu) passes for 3D convolution
|
||||
* [Fused Convolution+Convolution example](/examples/13_two_tensor_op_fusion/README.md)
|
||||
* Corrections and bug fixes reported by the CUTLASS community
|
||||
* Thank you for filing these issues!
|
||||
|
||||
|
||||
## [2.4.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.4.0) (2020-11-19)
|
||||
* Implicit GEMM convolution kernels supporting CUDA and Tensor Cores on NVIDIA GPUs
|
||||
* Operators: forward (Fprop), backward data gradient (Dgrad), and backward weight gradient (Wgrad) convolution
|
||||
* Data type: FP32, complex<FP32>, Tensor Float 32 (TF32), BFloat16 (BF16), Float16, Int4, Int8, Int32
|
||||
* Spatial dimensions: 1-D, 2-D, and 3-D
|
||||
* Layout: NHWC, NCxHWx
|
||||
* Implicit GEMM convolution components:
|
||||
* Global memory iterators supporting Fprop, Dgrad, and Wgrad
|
||||
* `MmaMultistage` for implicit GEMM convolution for NVIDIA Ampere architecture
|
||||
* `MmaPipeline` for implicit GEMM convolution for NVIDIA Volta and Turing architectures
|
||||
* [Documentation](/media/docs/implicit_gemm_convolution.md) describing Implicit GEMM Convolution algorithm and implementation
|
||||
|
||||
## [2.3.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.3.0) (2020-09-23)
|
||||
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
|
||||
* [Sparse Tensor Core GEMM kernels](test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu):
|
||||
* Direct access to Sparse Tensor Cores and maximum performance via [`mma.sp.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends)
|
||||
* Fast SGEMM targeting GeForce RTX 30-series CUDA Cores
|
||||
* Minor Features:
|
||||
* [Activation functions](/include/cutlass/epilogue/thread/activation.h) such as [GeLU](/include/cutlass/epilogue/thread/linear_combination_gelu.h) and [Sigmoid](/include/cutlass/epilogue/thread/linear_combination_sigmoid.h)
|
||||
* Small [matrix](/include/cutlass/matrix.h) and [quaternion](/include/cutlass/quaternion.h) template classes in device code
|
||||
* [Floating-point constants](/include/cutlass/constants.h)
|
||||
* NVIDIA Ampere GPU Architecture examples and documentation:
|
||||
* [Tensor Float 32](/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) and
|
||||
* [Sparse Tensor Cores](/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu)
|
||||
* Documentation added on CUTLASS [efficient row-major epilogue](/media/docs/gemm_api.md#efficient-epilogue)
|
||||
|
||||
## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08)
|
||||
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
|
||||
@ -101,27 +230,33 @@
|
||||
|
||||
## Copyright
|
||||
|
||||
Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
provided that the following conditions are met:
|
||||
* Redistributions of source code must retain the above copyright notice, this list of
|
||||
conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
conditions and the following disclaimer in the documentation and/or other materials
|
||||
provided with the distribution.
|
||||
* Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
to endorse or promote products derived from this software without specific prior written
|
||||
permission.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
|
||||
388
CMakeLists.txt
@ -1,23 +1,29 @@
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cmake_minimum_required(VERSION 3.12.4 FATAL_ERROR)
|
||||
@ -32,9 +38,15 @@ endif()
|
||||
|
||||
message(STATUS "CMake Version: ${CMAKE_VERSION}")
|
||||
|
||||
project(CUTLASS VERSION 2.2.0 LANGUAGES CXX)
|
||||
project(CUTLASS VERSION 2.9.0 LANGUAGES CXX)
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
|
||||
|
||||
if (CUDA_VERSION VERSION_LESS 10.2)
|
||||
message(WARNING "CUTLASS ${CUTLASS_VERSION} requires CUDA 10.2 or higher, and strongly recommends CUDA 11.0 or higher.")
|
||||
elseif (CUDA_VERSION VERSION_LESS 11.0)
|
||||
message(WARNING "CUTLASS ${CUTLASS_VERSION} support for CUDA ${CUDA_VERSION} is deprecated, please use CUDA 11.0 or higher.")
|
||||
endif()
|
||||
|
||||
find_package(Doxygen QUIET)
|
||||
|
||||
#
|
||||
@ -61,17 +73,23 @@ set(CUTLASS_ENABLE_HEADERS_ONLY OFF CACHE BOOL "Enable only the header library")
|
||||
|
||||
if(CUTLASS_ENABLE_HEADERS_ONLY)
|
||||
set(CUTLASS_ENABLE_EXAMPLES_INIT OFF)
|
||||
set(CUTLASS_ENABLE_TOOLS_INIT OFF)
|
||||
set(CUTLASS_ENABLE_TOOLS_INIT ON)
|
||||
set(CUTLASS_ENABLE_LIBRARY_INIT OFF)
|
||||
else()
|
||||
set(CUTLASS_ENABLE_EXAMPLES_INIT ON)
|
||||
set(CUTLASS_ENABLE_TOOLS_INIT ON)
|
||||
set(CUTLASS_ENABLE_LIBRARY_INIT ON)
|
||||
endif()
|
||||
|
||||
set(CUTLASS_TEST_UNIT_ENABLE_WARNINGS OFF CACHE BOOL "Enable warnings on waived unit tests.")
|
||||
|
||||
set(CUTLASS_ENABLE_EXAMPLES ${CUTLASS_ENABLE_EXAMPLES_INIT} CACHE BOOL "Enable CUTLASS Examples")
|
||||
set(CUTLASS_ENABLE_TOOLS ${CUTLASS_ENABLE_TOOLS_INIT} CACHE BOOL "Enable CUTLASS Tools")
|
||||
set(CUTLASS_ENABLE_LIBRARY ${CUTLASS_ENABLE_LIBRARY_INIT} CACHE BOOL "Enable CUTLASS Library")
|
||||
set(CUTLASS_ENABLE_PROFILER ${CUTLASS_ENABLE_LIBRARY} CACHE BOOL "Enable CUTLASS Profiler")
|
||||
|
||||
if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME})
|
||||
set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_TOOLS_INIT})
|
||||
set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_LIBRARY})
|
||||
else()
|
||||
set(CUTLASS_ENABLE_TESTS_INIT OFF)
|
||||
endif()
|
||||
@ -101,6 +119,9 @@ endif()
|
||||
if (NOT CUDA_VERSION VERSION_LESS 11.0)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 80)
|
||||
endif()
|
||||
if (NOT CUDA_VERSION VERSION_LESS 11.1 AND NOT CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 86)
|
||||
endif()
|
||||
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
|
||||
set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.")
|
||||
|
||||
@ -109,10 +130,6 @@ if (POLICY CMP0076)
|
||||
cmake_policy(SET CMP0076 NEW)
|
||||
endif()
|
||||
|
||||
if( NOT CMAKE_SIZEOF_VOID_P EQUAL 8 )
|
||||
message(FATAL_ERROR "CUTLASS requires a 64-bit compiler!")
|
||||
endif()
|
||||
|
||||
include(GNUInstallDirs)
|
||||
|
||||
link_directories(${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)
|
||||
@ -132,7 +149,12 @@ if (NOT (CMAKE_BUILD_TYPE OR CONFIGURATION_TYPES))
|
||||
endif()
|
||||
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
set(CUTLASS_LIBRARY_DEBUG_POSTFIX ".debug" CACHE STRING "Default postfix value for debug libraries")
|
||||
if (DEFINED CMAKE_DEBUG_POSTFIX)
|
||||
set(CUTLASS_LIBRARY_DEBUG_POSTFIX_INIT ${CMAKE_DEBUG_POSTFIX})
|
||||
else()
|
||||
set(CUTLASS_LIBRARY_DEBUG_POSTFIX_INIT .debug)
|
||||
endif()
|
||||
set(CUTLASS_LIBRARY_DEBUG_POSTFIX ${CUTLASS_LIBRARY_DEBUG_POSTFIX_INIT} CACHE STRING "Default postfix value for debug libraries")
|
||||
|
||||
if(WIN32)
|
||||
# On Windows we link against the shared (DLL) runtime. Change gtest settings to match this.
|
||||
@ -154,6 +176,11 @@ if (${CUTLASS_NVCC_VERBOSE})
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -v)
|
||||
endif()
|
||||
|
||||
#
|
||||
# CUTLASS NAMESPACE
|
||||
#
|
||||
set(CUTLASS_NAMESPACE "cutlass" CACHE STRING "Top level namespace of CUTLASS")
|
||||
|
||||
set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries into executables.")
|
||||
set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
|
||||
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
|
||||
@ -164,12 +191,28 @@ set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.
|
||||
#
|
||||
set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma delimited list of operation name filters. Default '' means all operations are enabled.")
|
||||
set(CUTLASS_LIBRARY_KERNELS "" CACHE STRING "Comma delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If 'all' is specified, all kernels are enabled.")
|
||||
|
||||
set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma delimited list of kernel names to exclude from build.")
|
||||
|
||||
# Test Levels L0, L1, L2
|
||||
set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.")
|
||||
|
||||
set(CUTLASS_TEST_ENABLE_CACHED_RESULTS ON CACHE BOOL "Enable caching and reuse of test results in unit tests")
|
||||
|
||||
set_property(CACHE CUTLASS_TEST_LEVEL PROPERTY STRINGS 0 1 2)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL})
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL})
|
||||
|
||||
if (CUTLASS_TEST_ENABLE_CACHED_RESULTS)
|
||||
message(STATUS "Enable caching of reference results in conv unit tests")
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1)
|
||||
endif()
|
||||
|
||||
set(CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED ON CACHE BOOL "Enable/Disable rigorous conv problem sizes in conv unit tests")
|
||||
|
||||
if (CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED)
|
||||
message(STATUS "Enable rigorous conv problem sizes in conv unit tests")
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED=1)
|
||||
endif()
|
||||
|
||||
#
|
||||
# CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations.
|
||||
@ -181,6 +224,10 @@ else()
|
||||
set(CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT ON)
|
||||
endif()
|
||||
|
||||
# Trace levels for debugging
|
||||
set(CUTLASS_DEBUG_TRACE_LEVEL "0" CACHE STRING "Level of debug tracing to perform.")
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_DEBUG_TRACE_LEVEL=${CUTLASS_DEBUG_TRACE_LEVEL})
|
||||
|
||||
set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL
|
||||
"Enable PTX mma instruction for collective matrix multiply operations.")
|
||||
|
||||
@ -219,7 +266,7 @@ if (NOT MSVC AND CUTLASS_NVCC_KEEP)
|
||||
# MSVC flow handles caching already, but for other generators we handle it here.
|
||||
set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files")
|
||||
file(MAKE_DIRECTORY ${CUTLASS_NVCC_KEEP_DIR})
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --keep) # --keep-dir may not work with nvcc for some directories.
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --keep -v) # --keep-dir may not work with nvcc for some directories.
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -save-temps=${CUTLASS_NVCC_KEEP_DIR})
|
||||
endif()
|
||||
|
||||
@ -241,6 +288,17 @@ if (NOT CMAKE_BUILD_TYPE MATCHES "Release")
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -lineinfo)
|
||||
endif()
|
||||
|
||||
#Report CUDA build flags
|
||||
if (CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
if(CUTLASS_CUDA_CLANG_FLAGS)
|
||||
message(STATUS "Using CLANG flags: ${CUTLASS_CUDA_CLANG_FLAGS}")
|
||||
endif()
|
||||
else()
|
||||
if(CUTLASS_CUDA_NVCC_FLAGS)
|
||||
message(STATUS "Using NVCC flags: ${CUTLASS_CUDA_NVCC_FLAGS}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
if( NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang" )
|
||||
message(FATAL_ERROR "Clang CUDA compilation requires Clang CXX compilation. Currently CMAKE_CXX_COMPILER is ${CMAKE_CXX_COMPILER_ID}" )
|
||||
@ -250,7 +308,14 @@ if(CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
message(FATAL_ERROR "Clang 7.0+ required for GPU compilation")
|
||||
endif()
|
||||
|
||||
# There are numerous Clang versions that can work with each CUDA toolkit and
|
||||
# the checks are not very useful so we are turning them off and using testing to
|
||||
# ensure the various combinations work properly.
|
||||
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-path=${CUDA_TOOLKIT_ROOT_DIR})
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -D__NV_NO_HOST_COMPILER_CHECK=1)
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wno-unknown-cuda-version)
|
||||
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -pragma-unroll-threshold=100000)
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -unroll-threshold=5000)
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wno-unused-command-line-argument)
|
||||
@ -269,18 +334,33 @@ if(CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
link_libraries(nvidia::cudart)
|
||||
endif()
|
||||
|
||||
# Support for 128-bit integers if using NVIDIA C++ compiler
|
||||
if (${CMAKE_CXX_COMPILER_ID} MATCHES "PGI" OR ${CMAKE_CXX_COMPILER_ID} MATCHES "NVHPC")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Mint128 ")
|
||||
endif()
|
||||
|
||||
if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.18)
|
||||
# CMake 3.18 added support for CUDA_ARCHITECTURES target property. We will use this
|
||||
# property for CMake 3.18+, so we request the NEW behavior for correct compatibility.
|
||||
# https://cmake.org/cmake/help/v3.18/policy/CMP0104.html#policy:CMP0104
|
||||
cmake_policy(SET CMP0104 NEW)
|
||||
endif()
|
||||
|
||||
function(cutlass_apply_cuda_gencode_flags TARGET)
|
||||
|
||||
set(NVCC_FLAGS)
|
||||
set(CLANG_FLAGS)
|
||||
set(__CMAKE_CUDA_ARCHS)
|
||||
foreach(ARCH ${CUTLASS_NVCC_ARCHS_ENABLED})
|
||||
list(APPEND CLANG_FLAGS --cuda-gpu-arch=sm_${ARCH})
|
||||
set(CODES)
|
||||
if(CUTLASS_NVCC_EMBED_CUBIN)
|
||||
list(APPEND CODES sm_${ARCH})
|
||||
list(APPEND __CMAKE_CUDA_ARCHS ${ARCH}-real)
|
||||
endif()
|
||||
if(CUTLASS_NVCC_EMBED_PTX)
|
||||
list(APPEND CODES compute_${ARCH})
|
||||
list(APPEND __CMAKE_CUDA_ARCHS ${ARCH}-virtual)
|
||||
endif()
|
||||
list(JOIN CODES "," CODES_STR)
|
||||
list(APPEND NVCC_FLAGS -gencode=arch=compute_${ARCH},code=[${CODES_STR}])
|
||||
@ -292,6 +372,8 @@ function(cutlass_apply_cuda_gencode_flags TARGET)
|
||||
PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CXX>:${CLANG_FLAGS}>
|
||||
)
|
||||
elseif(CMAKE_VERSION GREATER_EQUAL 3.18)
|
||||
set_property(TARGET ${TARGET} PROPERTY CUDA_ARCHITECTURES ${__CMAKE_CUDA_ARCHS})
|
||||
else()
|
||||
target_compile_options(
|
||||
${TARGET}
|
||||
@ -302,22 +384,39 @@ function(cutlass_apply_cuda_gencode_flags TARGET)
|
||||
|
||||
endfunction()
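As a concrete illustration of what the function above emits, here is a hedged sketch for a single enabled architecture; the target name is hypothetical, and the expanded flags are inferred from the loop body shown in the hunk:

```cmake
# Hypothetical usage sketch (not part of this change).
# With CUTLASS_NVCC_ARCHS_ENABLED=80, CUTLASS_NVCC_EMBED_CUBIN=ON and CUTLASS_NVCC_EMBED_PTX=ON,
# the loop above yields, for nvcc:  -gencode=arch=compute_80,code=[sm_80,compute_80]
# and, for clang:                   --cuda-gpu-arch=sm_80
add_executable(my_gemm_example gemm_example.cu)   # illustrative target
cutlass_apply_cuda_gencode_flags(my_gemm_example)
```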
|
||||
|
||||
# Cache the flags so they are available when the function below is called anywhere globally.
|
||||
|
||||
set(__CUTLASS_CUDA_FLAGS ${CUTLASS_CUDA_FLAGS} CACHE INTERNAL "")
|
||||
set(__CUTLASS_CUDA_FLAGS_RELEASE ${CUTLASS_CUDA_FLAGS_RELEASE} CACHE INTERNAL "")
|
||||
set(__CUTLASS_CUDA_FLAGS_RELWITHDEBINFO ${CUTLASS_CUDA_FLAGS_RELWITHDEBINFO} CACHE INTERNAL "")
|
||||
set(__CUTLASS_CUDA_FLAGS_DEBUG ${CUTLASS_CUDA_FLAGS_DEBUG} CACHE INTERNAL "")
|
||||
set(__CUTLASS_CUDA_CLANG_FLAGS ${CUTLASS_CUDA_CLANG_FLAGS} CACHE INTERNAL "")
|
||||
set(__CUTLASS_CUDA_CLANG_FLAGS_RELEASE ${CUTLASS_CUDA_CLANG_FLAGS_RELEASE} CACHE INTERNAL "")
|
||||
set(__CUTLASS_CUDA_CLANG_FLAGS_RELWITHDEBINFO ${CUTLASS_CUDA_CLANG_FLAGS_RELWITHDEBINFO} CACHE INTERNAL "")
|
||||
set(__CUTLASS_CUDA_CLANG_FLAGS_DEBUG ${CUTLASS_CUDA_CLANG_FLAGS_DEBUG} CACHE INTERNAL "")
|
||||
set(__CUTLASS_CUDA_NVCC_FLAGS ${CUTLASS_CUDA_NVCC_FLAGS} CACHE INTERNAL "")
|
||||
set(__CUTLASS_CUDA_NVCC_FLAGS_RELEASE ${CUTLASS_CUDA_NVCC_FLAGS_RELEASE} CACHE INTERNAL "")
|
||||
set(__CUTLASS_CUDA_NVCC_FLAGS_RELWITHDEBINFO ${CUTLASS_CUDA_NVCC_FLAGS_RELWITHDEBINFO} CACHE INTERNAL "")
|
||||
set(__CUTLASS_CUDA_NVCC_FLAGS_DEBUG ${CUTLASS_CUDA_NVCC_FLAGS_DEBUG} CACHE INTERNAL "")
|
||||
|
||||
function(cutlass_apply_standard_compile_options TARGET)
|
||||
|
||||
if(CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
set(CUDA_COMPILE_LANGUAGE CXX)
|
||||
set(_FLAGS ${CUTLASS_CUDA_FLAGS} ${CUTLASS_CUDA_CLANG_FLAGS})
|
||||
set(_FLAGS_RELEASE ${CUTLASS_CUDA_FLAGS_RELEASE} ${CUTLASS_CUDA_CLANG_FLAGS_RELEASE})
|
||||
set(_FLAGS_RELWITHDEBINFO ${CUTLASS_CUDA_FLAGS_RELWITHDEBINFO} ${CUTLASS_CUDA_CLANG_FLAGS_RELWITHDEBINFO})
|
||||
set(_FLAGS_DEBUG ${CUTLASS_CUDA_FLAGS_DEBUG} ${CUTLASS_CUDA_CLANG_FLAGS_DEBUG})
|
||||
set(_FLAGS ${__CUTLASS_CUDA_FLAGS} ${__CUTLASS_CUDA_CLANG_FLAGS})
|
||||
set(_FLAGS_RELEASE ${__CUTLASS_CUDA_FLAGS_RELEASE} ${__CUTLASS_CUDA_CLANG_FLAGS_RELEASE})
|
||||
set(_FLAGS_RELWITHDEBINFO ${__CUTLASS_CUDA_FLAGS_RELWITHDEBINFO} ${__CUTLASS_CUDA_CLANG_FLAGS_RELWITHDEBINFO})
|
||||
set(_FLAGS_DEBUG ${__CUTLASS_CUDA_FLAGS_DEBUG} ${__CUTLASS_CUDA_CLANG_FLAGS_DEBUG})
|
||||
else()
|
||||
set(CUDA_COMPILE_LANGUAGE CUDA)
|
||||
set(_FLAGS ${CUTLASS_CUDA_FLAGS} ${CUTLASS_CUDA_NVCC_FLAGS})
|
||||
set(_FLAGS_RELEASE ${CUTLASS_CUDA_FLAGS_RELEASE} ${CUTLASS_CUDA_NVCC_FLAGS_RELEASE})
|
||||
set(_FLAGS_RELWITHDEBINFO ${CUTLASS_CUDA_FLAGS_RELWITHDEBINFO} ${CUTLASS_CUDA_NVCC_FLAGS_RELWITHDEBINFO})
|
||||
set(_FLAGS_DEBUG ${CUTLASS_CUDA_FLAGS_DEBUG} ${CUTLASS_CUDA_NVCC_FLAGS_DEBUG})
|
||||
set(_FLAGS ${__CUTLASS_CUDA_FLAGS} ${__CUTLASS_CUDA_NVCC_FLAGS})
|
||||
set(_FLAGS_RELEASE ${__CUTLASS_CUDA_FLAGS_RELEASE} ${__CUTLASS_CUDA_NVCC_FLAGS_RELEASE})
|
||||
set(_FLAGS_RELWITHDEBINFO ${__CUTLASS_CUDA_FLAGS_RELWITHDEBINFO} ${__CUTLASS_CUDA_NVCC_FLAGS_RELWITHDEBINFO})
|
||||
set(_FLAGS_DEBUG ${__CUTLASS_CUDA_FLAGS_DEBUG} ${__CUTLASS_CUDA_NVCC_FLAGS_DEBUG})
|
||||
endif()
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE CUTLASS)
|
||||
|
||||
target_compile_options(
|
||||
${TARGET}
|
||||
PRIVATE
|
||||
@ -352,7 +451,7 @@ set_target_properties(CUTLASS PROPERTIES EXPORT_NAME cutlass)
|
||||
|
||||
set(CUTLASS_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include CACHE PATH "CUTLASS Header Library")
|
||||
|
||||
set(CUTLASS_GENERATOR_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tools/library/)
|
||||
set(CUTLASS_GENERATOR_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tools/library CACHE INTERNAL "Location of generator scripts")
|
||||
|
||||
# The following utility directory is needed even if the tools build is disabled, so it exists here.
|
||||
set(CUTLASS_TOOLS_UTIL_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tools/util/include CACHE INTERNAL "")
|
||||
@ -360,6 +459,7 @@ set(CUTLASS_TOOLS_UTIL_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tools/util/includ
|
||||
include_directories(${CUTLASS_INCLUDE_DIR})
|
||||
|
||||
target_compile_features(CUTLASS INTERFACE cxx_std_11)
|
||||
target_compile_definitions(CUTLASS INTERFACE CUTLASS_NAMESPACE=${CUTLASS_NAMESPACE})
|
||||
|
||||
if (NOT DEFINED CUTLASS_REVISION)
|
||||
|
||||
@ -448,27 +548,231 @@ endif()
|
||||
|
||||
################################################################################
|
||||
|
||||
include(CTest)
|
||||
enable_testing()
|
||||
if (NOT TARGET test_all)
|
||||
add_custom_target(test_all)
|
||||
endif()
|
||||
|
||||
set(CUTLASS_INSTALL_TESTS ON CACHE BOOL "Install test executables")
|
||||
set(CUTLASS_TEST_EXECUTION_ENVIRONMENT "" CACHE BOOL "Environment in which to invoke unit test executables")
|
||||
|
||||
set(CMAKE_TEST_INSTALL_PREFIX test CACHE STRING "Test root install location, relative to CMAKE_INSTALL_PREFIX.")
|
||||
set(CUTLASS_TEST_INSTALL_PREFIX ${CMAKE_TEST_INSTALL_PREFIX}/cutlass CACHE STRING "Test root install location, relative to CMAKE_INSTALL_PREFIX.")
|
||||
set(CUTLASS_TEST_INSTALL_BINDIR ${CUTLASS_TEST_INSTALL_PREFIX}/${CMAKE_INSTALL_BINDIR} CACHE STRING "Test root install location, relative to CMAKE_INSTALL_PREFIX.")
|
||||
set(CUTLASS_TEST_INSTALL_LIBDIR ${CUTLASS_TEST_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR} CACHE STRING "Test root install location, relative to CMAKE_INSTALL_PREFIX.")
|
||||
|
||||
install(DIRECTORY DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX})
|
||||
install(DIRECTORY DESTINATION ${CUTLASS_TEST_INSTALL_BINDIR})
|
||||
install(DIRECTORY DESTINATION ${CUTLASS_TEST_INSTALL_LIBDIR})
|
||||
install(DIRECTORY DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/ctest)
|
||||
|
||||
################################################################################
|
||||
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/cuBLAS.cmake)
|
||||
|
||||
if (CUTLASS_ENABLE_CUBLAS)
|
||||
target_compile_definitions(CUTLASS INTERFACE CUTLASS_ENABLE_CUBLAS=1)
|
||||
endif()
|
||||
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/cuDNN.cmake)
|
||||
|
||||
if (CUTLASS_ENABLE_CUDNN)
|
||||
target_compile_definitions(CUTLASS INTERFACE CUTLASS_ENABLE_CUDNN=1)
|
||||
endif()
|
||||
|
||||
################################################################################
|
||||
|
||||
if(CUTLASS_ENABLE_TOOLS)
|
||||
set(CUTLASS_CTEST_TEMPLATE_FILE ${CMAKE_CURRENT_LIST_DIR}/cmake/CTestTestfile.config.cmake)
|
||||
set(CUTLASS_CTEST_GENERATED_FILES "" CACHE INTERNAL "")
|
||||
|
||||
function(cutlass_add_executable_tests NAME TARGET)
|
||||
#
|
||||
# Generates test rules for `make test`, `make test_all`, and `ctest` invoked from either the
|
||||
# <CMAKE_BINARY_DIR> or the <CMAKE_INSTALL_PREFIX>/<CUTLASS_TEST_INSTALL_PREFIX> after installation.
|
||||
#
|
||||
# NAME: The base name for the test. Can be run with `make <NAME>` or `ctest -R 'c<NAME>'`.
|
||||
# TARGET: The target corresponding to the executable under test.
|
||||
# DISABLE_EXECUTABLE_INSTALL_RULE: An option, if given, that disables creating an install rule for TARGET.
|
||||
# DEPENDS: A list of targets or files on which this test is dependent.
|
||||
# DEPENDEES: A list of targets which should depend on this test.
|
||||
# TEST_COMMAND_OPTIONS: A list of variables (i.e. by reference params) which contain command line arguments
|
||||
# to pass to the test executable. A unique test with suffix _0, _1, ... is generated for each set of
|
||||
# options given. If this option is not used, a single test with no arguments is generated.
|
||||
# RESULT_CACHE_FILE: A file to be installed alongside the test executable with pre-computed
|
||||
# test results to speed up test runtime.
|
||||
#
|
||||
|
||||
set(options DISABLE_EXECUTABLE_INSTALL_RULE)
|
||||
set(oneValueArgs DISABLE_TESTS RESULT_CACHE_FILE)
|
||||
set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS)
|
||||
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
if (NOT DEFINED __DISABLE_TESTS)
|
||||
set(__DISABLE_TESTS OFF)
|
||||
endif()
|
||||
|
||||
if (__RESULT_CACHE_FILE)
|
||||
|
||||
add_custom_command(
|
||||
TARGET ${TARGET}
|
||||
POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND}
|
||||
ARGS -E copy ${__RESULT_CACHE_FILE} "$<TARGET_FILE_DIR:${TARGET}>"
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
if (NOT __DISABLE_EXECUTABLE_INSTALL_RULE AND CUTLASS_INSTALL_TESTS)
|
||||
|
||||
# file(RELATIVE_PATH CMAKE_CURRENT_BINARY_RELATIVE_DIR ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
||||
install(
|
||||
TARGETS ${TARGET}
|
||||
RUNTIME DESTINATION ${CUTLASS_TEST_INSTALL_BINDIR}
|
||||
)
|
||||
|
||||
if (__RESULT_CACHE_FILE)
|
||||
|
||||
install(
|
||||
FILES ${__RESULT_CACHE_FILE}
|
||||
DESTINATION ${CUTLASS_TEST_INSTALL_BINDIR}/
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
if (NOT __TEST_COMMAND_OPTIONS)
|
||||
set(__TEST_COMMAND_OPTIONS " ")
|
||||
endif()
|
||||
|
||||
list(LENGTH __TEST_COMMAND_OPTIONS CMD_COUNT)
|
||||
set(CMD_IDX 0)
|
||||
|
||||
if (CMD_COUNT GREATER 1)
|
||||
add_custom_target(${NAME} DEPENDS ${TARGET} ${__DEPENDS})
|
||||
foreach(DEPENDEE ${__DEPENDEES})
|
||||
add_dependencies(${DEPENDEE} ${NAME})
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
foreach(CMD_OPTIONS ${__TEST_COMMAND_OPTIONS})
|
||||
|
||||
if (CMD_COUNT GREATER 1)
|
||||
set(TEST_NAME ${NAME}_${CMD_IDX})
|
||||
else()
|
||||
set(TEST_NAME ${NAME})
|
||||
endif()
|
||||
|
||||
# The following rigmarole is needed to deal with spaces and possible quotes in
|
||||
# command line arguments. The options are passed "by reference" as the actual
|
||||
# variable names holding the real options. We then expand these in a way that
|
||||
# preserves any quotes. Note, they have to be in this order for it to work for
|
||||
# all the use cases below.
|
||||
|
||||
set(CMD_OPTIONS ${${CMD_OPTIONS}})
|
||||
list(JOIN CMD_OPTIONS " " TEST_COMMAND_OPTIONS)
|
||||
separate_arguments(CMD_OPTIONS)
|
||||
|
||||
add_custom_target(
|
||||
${TEST_NAME}
|
||||
COMMAND
|
||||
${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $<TARGET_FILE:${TARGET}> ${CMD_OPTIONS}
|
||||
DEPENDS
|
||||
${TARGET}
|
||||
)
|
||||
|
||||
if (CMD_COUNT GREATER 1)
|
||||
add_dependencies(${NAME} ${TEST_NAME})
|
||||
endif()
|
||||
|
||||
foreach(DEPENDEE ${__DEPENDEES})
|
||||
add_dependencies(${DEPENDEE} ${TEST_NAME})
|
||||
endforeach()
|
||||
|
||||
add_test(
|
||||
NAME c${TEST_NAME}
|
||||
COMMAND ${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $<TARGET_FILE:${TARGET}> ${CMD_OPTIONS}
|
||||
)
|
||||
|
||||
set_tests_properties(c${TEST_NAME} PROPERTIES DISABLED ${__DISABLE_TESTS})
|
||||
|
||||
if (CUTLASS_INSTALL_TESTS)
|
||||
|
||||
# To run the tests from an install package with tests enabled, we need to generate test files
|
||||
# that don't rely on the current directory structure in build.
|
||||
|
||||
set(TEST_NAME c${TEST_NAME})
|
||||
set(TEST_EXE $<TARGET_FILE_NAME:${TARGET}>)
|
||||
set(TEST_EXE_WORKING_DIRECTORY ./${CMAKE_INSTALL_BINDIR})
|
||||
configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.config.cmake" @ONLY)
|
||||
|
||||
file(GENERATE
|
||||
OUTPUT "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.cmake"
|
||||
INPUT "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.config.cmake"
|
||||
)
|
||||
|
||||
install(
|
||||
FILES "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.cmake"
|
||||
DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/ctest/
|
||||
)
|
||||
|
||||
set(CUTLASS_CTEST_GENERATED_FILES ${CUTLASS_CTEST_GENERATED_FILES};ctest/CTestTestfile.${TEST_NAME}.cmake CACHE INTERNAL "")
|
||||
|
||||
endif()
|
||||
|
||||
math(EXPR CMD_IDX "${CMD_IDX} + 1")
|
||||
|
||||
endforeach()
|
||||
|
||||
endfunction()
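To make the documented parameters concrete, a hedged invocation sketch follows; the target, test name, filter, and cache-file path are illustrative stand-ins rather than names taken from this change:

```cmake
# Hypothetical invocation sketch based on the argument list documented above.
set(GEMM_TEST_OPTIONS_0 --gtest_filter=SM80_*)   # options are passed "by reference" via variable names
cutlass_add_executable_tests(
  test_unit_gemm_example               # NAME: run via `make test_unit_gemm_example` or `ctest -R 'ctest_unit_gemm_example'`
  cutlass_test_unit_gemm_example       # TARGET: the executable under test (illustrative)
  DEPENDEES test_all                   # make the aggregate test_all target depend on this test
  TEST_COMMAND_OPTIONS GEMM_TEST_OPTIONS_0
  RESULT_CACHE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/cached_results_gemm.txt
  )
```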
|
||||
|
||||
if (CUTLASS_ENABLE_TOOLS)
|
||||
add_subdirectory(tools)
|
||||
if (CUTLASS_ENABLE_PROFILER)
|
||||
add_dependencies(test_all test_profiler)
|
||||
endif()
|
||||
endif()
|
||||
if(CUTLASS_ENABLE_EXAMPLES)
|
||||
if (CUTLASS_ENABLE_EXAMPLES)
|
||||
add_subdirectory(examples)
|
||||
add_dependencies(test_all test_examples)
|
||||
endif()
|
||||
|
||||
if(CUTLASS_ENABLE_TESTS)
|
||||
include(CTest)
|
||||
enable_testing()
|
||||
if (CUTLASS_ENABLE_TESTS)
|
||||
add_subdirectory(test)
|
||||
add_dependencies(test_all test_unit)
|
||||
endif()
|
||||
|
||||
if (CUTLASS_INSTALL_TESTS)
|
||||
|
||||
file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/cmake")
|
||||
|
||||
file(WRITE "${CMAKE_BINARY_DIR}/cmake/CTestTestfile.cmake" "# Generated File\n")
|
||||
foreach(GENERATED_FILE ${CUTLASS_CTEST_GENERATED_FILES})
|
||||
file(APPEND "${CMAKE_BINARY_DIR}/cmake/CTestTestfile.cmake" "include(${GENERATED_FILE})\n")
|
||||
endforeach()
|
||||
|
||||
install(
|
||||
FILES "${CMAKE_BINARY_DIR}/cmake/CTestTestfile.cmake"
|
||||
DESTINATION "${CUTLASS_TEST_INSTALL_PREFIX}/"
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
#? install(
|
||||
#? FILES ${CMAKE_BINARY_DIR}/CTestTestfile.cmake
|
||||
#? DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/
|
||||
#? )
|
||||
#?
|
||||
#? install(
|
||||
#? DIRECTORY
|
||||
#? ${CMAKE_BINARY_DIR}/tools
|
||||
#? ${CMAKE_BINARY_DIR}/test
|
||||
#? DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/
|
||||
#? FILES_MATCHING PATTERN "CTestTestfile.cmake"
|
||||
#? )
|
||||
|
||||
################################################################################
|
||||
|
||||
install(
|
||||
|
||||
CONTRIBUTORS.md
@@ -16,10 +16,17 @@ Naila Farooqui
|
||||
Piotr Majcher
|
||||
Paul Springer
|
||||
Jin Wang
|
||||
Aniket Shivam
|
||||
Chinmay Talegaonkar
|
||||
Shang Zhang
|
||||
Scott Yokim
|
||||
Markus Hohnerbach
|
||||
Aditya Atluri
|
||||
David Tanner
|
||||
Manikandan Ananth
|
||||
|
||||
## CUTLASS Product Manager
|
||||
Matthew Nicely
|
||||
|
||||
## CONTRIBUTORS
|
||||
Timothy Costa
|
||||
@ -55,3 +62,4 @@ Bryce Lelbach
|
||||
Joel McCormack
|
||||
Kyrylo Perelygin
|
||||
|
||||
|
||||
|
||||
77
CUDA.cmake
@ -1,23 +1,29 @@
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
if(CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
@ -74,7 +80,7 @@ find_library(
|
||||
lib64
|
||||
lib
|
||||
NO_DEFAULT_PATH
|
||||
# We aren't going to search any system paths. We want to find the runtime
|
||||
# We aren't going to search any system paths. We want to find the runtime
|
||||
# in the CUDA toolkit we're building against.
|
||||
)
|
||||
|
||||
@ -89,10 +95,10 @@ if(NOT TARGET cudart AND CUDART_LIBRARY)
|
||||
# from the PATH search.
|
||||
else()
|
||||
add_library(cudart SHARED IMPORTED GLOBAL)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_library(nvidia::cudart ALIAS cudart)
|
||||
|
||||
|
||||
set_property(
|
||||
TARGET cudart
|
||||
PROPERTY IMPORTED_LOCATION
|
||||
@ -120,7 +126,7 @@ find_library(
|
||||
lib64/stubs
|
||||
lib/stubs
|
||||
NO_DEFAULT_PATH
|
||||
# We aren't going to search any system paths. We want to find the runtime
|
||||
# We aren't going to search any system paths. We want to find the runtime
|
||||
# in the CUDA toolkit we're building against.
|
||||
)
|
||||
|
||||
@ -135,10 +141,10 @@ if(NOT TARGET cuda_driver AND CUDA_DRIVER_LIBRARY)
|
||||
# from the PATH search.
|
||||
else()
|
||||
add_library(cuda_driver SHARED IMPORTED GLOBAL)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_library(nvidia::cuda_driver ALIAS cuda_driver)
|
||||
|
||||
|
||||
set_property(
|
||||
TARGET cuda_driver
|
||||
PROPERTY IMPORTED_LOCATION
|
||||
@ -164,7 +170,7 @@ find_library(
|
||||
lib64
|
||||
lib
|
||||
NO_DEFAULT_PATH
|
||||
# We aren't going to search any system paths. We want to find the runtime
|
||||
# We aren't going to search any system paths. We want to find the runtime
|
||||
# in the CUDA toolkit we're building against.
|
||||
)
|
||||
|
||||
@ -179,10 +185,10 @@ if(NOT TARGET nvrtc AND NVRTC_LIBRARY)
|
||||
# from the PATH search.
|
||||
else()
|
||||
add_library(nvrtc SHARED IMPORTED GLOBAL)
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
add_library(nvidia::nvrtc ALIAS nvrtc)
|
||||
|
||||
|
||||
set_property(
|
||||
TARGET nvrtc
|
||||
PROPERTY IMPORTED_LOCATION
|
||||
@ -204,7 +210,7 @@ include_directories(SYSTEM ${CUDA_INCLUDE_DIRS})
|
||||
# paths by default, so we add it explicitly here.
|
||||
|
||||
function(cutlass_correct_source_file_language_property)
|
||||
if(CUDA_COMPILER MATCHES "clang")
|
||||
if(CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
foreach(File ${ARGN})
|
||||
if(File MATCHES ".*\.cu$")
|
||||
set_source_files_properties(${File} PROPERTIES LANGUAGE CXX)
|
||||
@ -213,7 +219,13 @@ function(cutlass_correct_source_file_language_property)
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
set(CUTLASS_UNITY_BUILD_ENABLED OFF CACHE BOOL "Enable combined source compilation")
|
||||
if (MSVC OR CUTLASS_LIBRARY_KERNELS MATCHES "all")
|
||||
set(CUTLASS_UNITY_BUILD_ENABLED_INIT ON)
|
||||
else()
|
||||
set(CUTLASS_UNITY_BUILD_ENABLED_INIT OFF)
|
||||
endif()
|
||||
|
||||
set(CUTLASS_UNITY_BUILD_ENABLED ${CUTLASS_UNITY_BUILD_ENABLED_INIT} CACHE BOOL "Enable combined source compilation")
|
||||
set(CUTLASS_UNITY_BUILD_BATCH_SIZE 16 CACHE STRING "Batch size for unified source files")
|
||||
|
||||
function(cutlass_unify_source_files TARGET_ARGS_VAR)
|
||||
@ -235,7 +247,7 @@ function(cutlass_unify_source_files TARGET_ARGS_VAR)
|
||||
|
||||
set(CUDA_FILE_ARGS)
|
||||
set(TARGET_SOURCE_ARGS)
|
||||
|
||||
|
||||
foreach(ARG ${__UNPARSED_ARGUMENTS})
|
||||
if(${ARG} MATCHES ".*\.cu$")
|
||||
list(APPEND CUDA_FILE_ARGS ${ARG})
|
||||
@ -243,7 +255,7 @@ function(cutlass_unify_source_files TARGET_ARGS_VAR)
|
||||
list(APPEND TARGET_SOURCE_ARGS ${ARG})
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
|
||||
list(LENGTH CUDA_FILE_ARGS NUM_CUDA_FILE_ARGS)
|
||||
while(NUM_CUDA_FILE_ARGS GREATER 0)
|
||||
list(SUBLIST CUDA_FILE_ARGS 0 ${__BATCH_SIZE} CUDA_FILE_BATCH)
|
||||
@ -273,7 +285,6 @@ function(cutlass_unify_source_files TARGET_ARGS_VAR)
|
||||
set(${TARGET_ARGS_VAR} ${TARGET_SOURCE_ARGS} PARENT_SCOPE)
|
||||
|
||||
endfunction()
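For orientation, here is a hedged sketch of how the unity-build helper above is meant to be called; the file names are hypothetical, and the BATCH_SIZE keyword is an assumption based on the `__BATCH_SIZE` variable used inside the function body shown in this hunk:

```cmake
# Hypothetical sketch (not part of this change): let CUTLASS batch .cu sources
# before they are handed to add_library()/add_executable().
set(EXAMPLE_SOURCES gemm_a.cu gemm_b.cu gemm_c.cu host_util.cpp)   # illustrative inputs
cutlass_unify_source_files(EXAMPLE_UNIFIED_SOURCES ${EXAMPLE_SOURCES} BATCH_SIZE 2)
# EXAMPLE_UNIFIED_SOURCES is set in the caller's scope; .cu inputs are grouped in
# batches of up to BATCH_SIZE (assumed to default to CUTLASS_UNITY_BUILD_BATCH_SIZE),
# while non-.cu arguments are passed through as-is.
add_library(example_kernels ${EXAMPLE_UNIFIED_SOURCES})
```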
|
||||
|
||||
function(cutlass_add_library NAME)
|
||||
|
||||
set(options)
|
||||
|
||||
42
LICENSE.txt
@ -1,23 +1,27 @@
|
||||
Copyright (c) 2017 - 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
* Neither the name of the NVIDIA CORPORATION nor the
|
||||
names of its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
22 PUBLICATIONS.md Normal file
@@ -0,0 +1,22 @@
# Publications Using Cutlass

## 2022

- ["Bolt: Bridging the Gap between Auto-tuners and Hardware-native Performance"](https://arxiv.org/abs/2110.15238). Jiarong Xing, Leyuan Wang, Shang Zhang, Jack Chen, Ang Chen, Yibo Zhu. _Proceedings of the 5th MLSys Conference_, August 2022.

- ["Recovering single precision accuracy from Tensor Cores while surpassing the FP32 theoretical peak performance"](https://arxiv.org/abs/2203.03341). Hiroyuki Ootomo, Rio Yokota. _International Journal of High Performance Computing_, March 2022.

## 2021

- ["Arithmetic-intensity-guided fault tolerance for neural network inference on GPUs"](https://dl.acm.org/doi/abs/10.1145/3458817.3476184). Jack Kosaian, K. V. Rashmi. _Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis_, November 2021.

- ["Real-time Neural Radiance Caching for Path Tracing"](https://d1qx31qr3h6wln.cloudfront.net/publications/paper_4.pdf). Thomas Muller, Fabrice Rousselle, Jan Novak, Alex Keller. _ACM Trans. Graph._, August 2021.

## 2020

- ["Scalable Knowledge Graph Analytics at 136 Petaflop/s"](https://www.computer.org/csdl/proceedings-article/sc/2020/999800a061/1oeORDgCM0g). Ramakrishnan Kannan, Piyush Sao, Hao Lu, Drahomira Herrmannova, Vijay Thakkar, Robert Patton, Richard Vuduc, Thomas Potok. _Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis_, November 2020.

- ["Accelerating Sparse DNN Models without Hardware-Support via Tile-Wise Sparsity"](https://arxiv.org/abs/2008.13006). Cong Guo, Bo Yang Hsueh, Jingwen Leng, Yuxian Qiu, Yue Guan, Zehuan Wang, Xiaoying Jia, Xipeng Li, Minyi Guo, Yuhao Zhu. _Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis_, November 2020.

- ["Strassen's Algorithm Reloaded on GPUs"](https://dl.acm.org/doi/10.1145/3372419). Jianyu Huang, Chenhan D. Yu, Robert A. van de Geijn. _ACM Transactions on Mathematical Software_, March 2020.
373 README.md
@@ -1,15 +1,15 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")

# CUTLASS 2.2
# CUTLASS 2.9

_CUTLASS 2.2 - June 2020_
_CUTLASS 2.9 - April 2022_

CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
It incorporates strategies for hierarchical decomposition and data movement similar
to those used to implement cuBLAS. CUTLASS decomposes these "moving parts" into
reusable, modular software components abstracted by C++ template classes. These
thread-wide, warp-wide, block-wide, and device-wide primitives can be specialized
high-performance matrix-multiplication (GEMM) and related computations at all levels
and scales within CUDA. It incorporates strategies for hierarchical decomposition and
data movement similar to those used to implement cuBLAS and cuDNN. CUTLASS decomposes
these "moving parts" into reusable, modular software components abstracted by C++ template
classes. These thread-wide, warp-wide, block-wide, and device-wide primitives can be specialized
and tuned via custom tiling sizes, data types, and other algorithmic policy. The
resulting flexibility simplifies their use as building blocks within custom kernels
and applications.
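As a concrete illustration of the device-wide layer, the sketch below wraps `cutlass::gemm::device::Gemm` for a column-major SGEMM, following the same pattern as `examples/00_basic_gemm` (which appears later in this diff). The wrapper name and the error mapping are ours, not part of the library:

```c++
#include "cutlass/gemm/device/gemm.h"

// Single-precision GEMM, C = alpha * A * B + beta * C, with all operands column-major.
using ColumnMajor = cutlass::layout::ColumnMajor;
using Gemm = cutlass::gemm::device::Gemm<float, ColumnMajor,   // ElementA, LayoutA
                                         float, ColumnMajor,   // ElementB, LayoutB
                                         float, ColumnMajor>;  // ElementC, LayoutC

cudaError_t cutlass_sgemm_nn(int M, int N, int K,
                             float alpha, float const *A, int lda,
                             float const *B, int ldb,
                             float beta, float *C, int ldc) {
  Gemm gemm_op;

  // Arguments bundle the problem size, tensor references, and epilogue scalars.
  Gemm::Arguments args({M, N, K}, {A, lda}, {B, ldb}, {C, ldc}, {C, ldc}, {alpha, beta});

  cutlass::Status status = gemm_op(args);
  return (status == cutlass::Status::kSuccess) ? cudaSuccess : cudaErrorUnknown;
}
```

Swapping the element types, layouts, or adding tile-size and architecture template arguments selects a different specialization without changing the call site.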
@@ -20,59 +20,59 @@ multiply-accumulate abstractions for half-precision floating
point (FP16), BFloat16 (BF16), Tensor Float 32 (TF32),
single-precision floating point (FP32), double-precision floating
point (FP64) types, integer data types (4b and 8b), and binary data types (1b).

Furthermore, CUTLASS demonstrates warp-synchronous matrix multiply operations
CUTLASS demonstrates warp-synchronous matrix multiply operations
targeting the programmable, high-throughput _Tensor Cores_ implemented by
NVIDIA's Volta, Turing, and Ampere architectures.

CUTLASS implements high-performance Convolution via the implicit GEMM algorithm.
Implicit GEMM is the formulation of a convolution operation as a GEMM, thereby taking advantage of
CUTLASS's modular GEMM pipeline.
This allows CUTLASS to build convolutions by reusing highly optimized warp-wide GEMM components and below.
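A short sketch of that mapping for forward propagation (fprop), assuming the usual im2col-style formulation; the tensor extents N, H, W, C, K, R, S, P, Q follow the conventions used by the CUTLASS profiler later in this README, and the helper itself is illustrative rather than a CUTLASS API:

```c++
#include <cstdint>

// Implicit GEMM maps a Conv2d fprop problem onto a GEMM with these extents,
// without ever materializing the im2col matrix:
//   GEMM_M = N * P * Q   (one row per output pixel)
//   GEMM_N = K           (one column per filter / output channel)
//   GEMM_K = C * R * S   (reduction over input channels and filter taps)
struct GemmExtents { int64_t m, n, k; };

inline GemmExtents implicit_gemm_fprop(int N, int P, int Q, int K, int C, int R, int S) {
  return { int64_t(N) * P * Q, int64_t(K), int64_t(C) * R * S };
}
```

For the conv2d problem profiled later in this README (`--n=8 --h=224 --w=224 --c=128 --k=128 --r=3 --s=3`, so P = Q = 224), this yields a GEMM with M = 401,408, N = 128, and K = 1,152.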

See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.

See the [functionality listing](media/docs/functionality.md) for the list of operations
See the [functionality listing](/media/docs/functionality.md) for the list of operations
supported at each level of the execution model hierarchy.

# What's New in CUTLASS 2.2
# What's New in CUTLASS 2.9

CUTLASS 2.2 is a significant update to CUTLASS adding:
CUTLASS 2.9 is an update to CUTLASS adding:
- [First layer Convolution kernels](/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment
- [BLAS3](https://docs.nvidia.com/cuda/cublas/index.html#cublas-level-3-function-reference) operators accelerated by Tensor Cores
  - [SYRK](/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu), [HERK](/test/unit/gemm/device/herk_cf32h_cf32n_tensor_op_fast_f32_sm80.cu),
  - [SYR2K](/test/unit/gemm/device/syr2k_f32n_f32n_tensor_op_fast_f32_sm80.cu), [HER2K](/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu),
  - [Out-of-place TRMM](/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu), and
  - [SYMM](/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu), [HEMM](/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_fast_f32_ls_sm80.cu)
- [CUTLASS Python](/examples/40_cutlass_py) demonstrating JIT compilation of CUTLASS kernels and a Python-based runtime using [CUDA Python](https://developer.nvidia.com/cuda-python)
- [GEMM + Softmax example](/examples/35_gemm_softmax)
- Optimal performance using [CUDA 11.6u2](https://developer.nvidia.com/cuda-downloads)
- Updates and bugfixes from the community (thanks!)
- **Deprecation announcement:** CUTLASS plans to deprecate the following:
  - Maxwell and Pascal GPU architectures
  - Ubuntu 16.04
  - CUDA 10.2

- Coverage of [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
  - Tensor Core-accelerated GEMMs targeting Tensor Float 32, BFloat16, and double-precision data types
  - Deep software pipelines using asynchronous copy
  - Described in [GTC 2020 Webinar (SR 21745)](https://developer.nvidia.com/gtc/2020/video/s21745)
- Intended to be compiled with [CUDA 11 Toolkit](https://developer.nvidia.com/cuda-toolkit)

# What's New in CUTLASS 2.1
|
||||
|
||||
CUTLASS 2.1 is a minor update to CUTLASS 2.0 adding:
|
||||
|
||||
- [Planar complex GEMM kernels](/examples/10_planar_complex/planar_complex.cu) targeting Volta and Turing Tensor Cores
|
||||
- BLAS-style API to launch kernels compiled into the [CUTLASS Library](/media/docs/quickstart.md#cutlass-library)
|
||||
|
||||
# What's New in CUTLASS 2.0
|
||||
|
||||
CUTLASS 2.0 is a substantial refactoring from the previous version, intended to offer:
|
||||
|
||||
- Better performance over 1.x, particularly for kernels targeting Turing Tensor Cores
|
||||
- Robust and durable templates that reliably span the design space
|
||||
- Encapsulated functionality that may be reusable in other contexts
|
||||
|
||||
**See the [CHANGELOG](CHANGELOG.md) for more details.**
|
||||
**See the [CHANGELOG](CHANGELOG.md) for a detailed listing of releases and updates.**
|
||||
|
||||
# Performance
|
||||
|
||||
<p align="center"><img src=/media/images/cutlass-performance-plot.png></p>
|
||||
<p align="center"><img src=/media/images/cutlass-2.8-gemm-performance.png></p>
|
||||
|
||||
CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
|
||||
they exhibit performance comparable to cuBLAS for scalar GEMM
|
||||
computations. The above figure shows CUTLASS performance relative to cuBLAS
|
||||
for large matrix dimensions on an NVIDIA GeForce 2080 Ti, an NVIDIA A100, and an NVIDIA TitanV
|
||||
using CUDA 11.0 Toolkit. Tensor Core operations are implemented using CUDA's
|
||||
for large matrix dimensions on an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/),
|
||||
an [NVIDIA A2](https://www.nvidia.com/en-us/data-center/products/a2/),
|
||||
an [NVIDIA TitanV](https://www.nvidia.com/en-us/titan/titan-v/),
|
||||
and an [NVIDIA GeForce 2080 Ti](https://www.nvidia.com/en-us/geforce/graphics-cards/rtx-2080-ti/)
|
||||
compiled with the [CUDA 11.5 Toolkit](https://developer.nvidia.com/cuda-downloads). Tensor Core operations are implemented using CUDA's
|
||||
[mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma).
|
||||
|
||||
# Compatibility
|
||||
|
||||
CUTLASS requires a C++11 host compiler and
|
||||
performs best when built with the [CUDA 11.0 Toolkit](https://developer.nvidia.com/cuda-toolkit).
|
||||
It is compatible with CUDA 9.2, CUDA 10.0, CUDA 10.1, and CUDA 10.2.
|
||||
performs best when built with the [**CUDA 11.6u2 Toolkit**](https://developer.nvidia.com/cuda-toolkit).
|
||||
It is also compatible with CUDA 11.0, CUDA 11.1, CUDA 11.2, CUDA 11.3, CUDA 11.4, and CUDA 11.5.
|
||||
|
||||
We have tested the following environments.
|
||||
|
||||
@@ -80,35 +80,40 @@ We have tested the following environments.

|-----------------|----------|
| Windows 10 | Microsoft Visual Studio 2015|
| | Microsoft Visual Studio 2017|
| Ubuntu 16.04 | GCC 5.4.0 |
| | Microsoft Visual Studio 2019|
| Ubuntu 18.04 | GCC 7.5.0 |
| Ubuntu 20.04 | GCC 10.3.0 |
| Ubuntu 21.04 | GCC 11.2.0 |

Additionally, CUTLASS may be built with clang.
See [these instructions](media/docs/quickstart.md#clang) for more details.

CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on
any Maxwell-, Pascal-, Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU.
any Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU.

|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit**|**CUDA Toolkit Enabling Native Tensor Cores**|
|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit**|**Minimum CUDA Toolkit Enabling Native Tensor Cores**|
|---|---|---|---|
|NVIDIA Tesla P100|6.0|9.2| |
|NVIDIA GeForce 1080|6.1|9.2| |
|NVIDIA TitanXP|6.1|9.2| |
|NVIDIA Tesla V100|7.0|9.2|10.1|
|NVIDIA TitanV|7.0|9.2|10.1|
|NVIDIA GeForce RTX 2080 TI, 2080, 2070|7.5|10.0|10.2|
|NVIDIA Tesla T4|7.5|10.0|10.2|
|NVIDIA A100|8.0|11.0|11.0|
|NVIDIA A10 |8.6|11.1|11.1|
|NVIDIA GeForce 3090|8.6|11.1|11.1|

For all GPUs, we recommend compiling with the [CUDA 11.6u2 Toolkit](https://developer.nvidia.com/cuda-toolkit)
for best performance.
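When picking a value for `CUTLASS_NVCC_ARCHS` (see the build instructions below), it can help to query the compute capability of the GPU you intend to run on. A minimal sketch using only the CUDA runtime API; the mapping from `major.minor` to an architecture number is the usual convention (8.0 -> 80, 7.5 -> 75) and the helper is ours:

```c++
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  cudaDeviceProp prop;
  cudaError_t err = cudaGetDeviceProperties(&prop, 0);
  if (err != cudaSuccess) {
    std::fprintf(stderr, "cudaGetDeviceProperties failed: %s\n", cudaGetErrorString(err));
    return 1;
  }
  // e.g. "NVIDIA A100 ...: compute capability 8.0 -> CUTLASS_NVCC_ARCHS=80"
  std::printf("%s: compute capability %d.%d -> CUTLASS_NVCC_ARCHS=%d%d\n",
              prop.name, prop.major, prop.minor, prop.major, prop.minor);
  return 0;
}
```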
|
||||
# Documentation
|
||||
|
||||
CUTLASS 2.2 is described in the following documents and the accompanying
|
||||
CUTLASS is described in the following documents and the accompanying
|
||||
[Doxygen documentation](https://nvidia.github.io/cutlass).
|
||||
|
||||
- [Quick Start Guide](/media/docs/quickstart.md) - build and run CUTLASS
|
||||
- [Functionality](/media/docs/functionality.md) - summarizes functionality available in CUTLASS
|
||||
- [Efficient GEMM in CUDA](media/docs/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA
|
||||
- [GEMM API](media/docs/gemm_api.md) - describes the CUTLASS GEMM model and C++ template concepts
|
||||
- [Implicit GEMM Convolution](media/docs/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS
|
||||
- [Code Organization](media/docs/code_organization.md) - describes the organization and contents of the CUTLASS project
|
||||
- [Terminology](media/docs/terminology.md) - describes terms used in the code
|
||||
- [Programming Guidelines](media/docs/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++
|
||||
@ -131,19 +136,19 @@ CUTLASS unit tests, examples, and utilities can be build with CMake starting ver
|
||||
Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed
|
||||
on your system.
|
||||
|
||||
```
|
||||
```bash
|
||||
$ export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc
|
||||
```
|
||||
|
||||
Create a build directory within the CUTLASS project, then run CMake. By default CUTLASS will build kernels
|
||||
for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, and 8.0. To reduce compile time you can specify
|
||||
for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, 8.0, and 8.6. To reduce compile time you can specify
|
||||
the architectures to build CUTLASS for by changing the CMake configuration setting
|
||||
`CUTLASS_NVCC_ARCHS`.
|
||||
|
||||
```
|
||||
```bash
|
||||
$ mkdir build && cd build
|
||||
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS=75 # compiles for NVIDIA's Turing GPU architecture
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS=80 # compiles for NVIDIA's Ampere Architecture
|
||||
```
|
||||
|
||||
From the `build/` directory, compile and run the CUTLASS unit tests by building the target `test_unit` with make.
|
||||
@ -151,7 +156,7 @@ From the `build/` directory, compile and run the CUTLASS unit tests by building
|
||||
The unit tests are organized as several binaries mirroring the top-level namespaces of CUTLASS,
|
||||
and they may be executed in parallel via make's `-j` command line argument.
|
||||
|
||||
```
|
||||
```bash
|
||||
$ make test_unit -j
|
||||
...
|
||||
...
|
||||
@ -182,6 +187,8 @@ include/ # client applications should target this directory
|
||||
|
||||
arch/ # direct exposure of architecture features (including instruction-level GEMMs)
|
||||
|
||||
conv/ # code specialized for convolution
|
||||
|
||||
gemm/ # code specialized for general matrix product computations
|
||||
|
||||
layout/ # layout definitions for matrices, tensors, and other mathematical objects in memory
|
||||
@ -201,34 +208,49 @@ include/ # client applications should target this directory
|
||||
|
||||
```
|
||||
examples/
|
||||
00_basic_gemm/ # launches a basic GEMM with single precision inputs and outputs
|
||||
00_basic_gemm/ # launches a basic GEMM with single precision inputs and outputs
|
||||
|
||||
01_cutlass_utilities/ # demonstrates CUTLASS Utilities for allocating and initializing tensors
|
||||
01_cutlass_utilities/ # demonstrates CUTLASS Utilities for allocating and initializing tensors
|
||||
|
||||
02_dump_reg_smem/ # debugging utilities for printing register and shared memory contents
|
||||
02_dump_reg_smem/ # debugging utilities for printing register and shared memory contents
|
||||
|
||||
03_visualize_layout/ # utility for visualizing all layout functions in CUTLASS
|
||||
03_visualize_layout/ # utility for visualizing all layout functions in CUTLASS
|
||||
|
||||
04_tile_iterator/ # example demonstrating an iterator over tiles in memory
|
||||
04_tile_iterator/ # example demonstrating an iterator over tiles in memory
|
||||
|
||||
05_batched_gemm/ # example demonstrating CUTLASS's batched strided GEMM operation
|
||||
05_batched_gemm/ # example demonstrating CUTLASS's batched strided GEMM operation
|
||||
|
||||
06_splitK_gemm/ # example demonstrating CUTLASS's Split-K parallel reduction kernel
06_splitK_gemm/ # example demonstrating CUTLASS's Split-K parallel reduction kernel
|
||||
|
||||
07_volta_tensorop_gemm/ # example demonstrating mixed precision GEMM using Volta Tensor Cores
|
||||
07_volta_tensorop_gemm/ # example demonstrating mixed precision GEMM using Volta Tensor Cores
|
||||
|
||||
08_turing_tensorop_gemm/ # example demonstrating integer GEMM using Turing Tensor Cores
|
||||
08_turing_tensorop_gemm/ # example demonstrating integer GEMM using Turing Tensor Cores
|
||||
|
||||
10_planar_complex/ # example demonstrating planar complex GEMM kernels
|
||||
09_turing_tensorop_conv2dfprop/ # example demonstrating integer implicit GEMM convolution (forward propagation) using Turing Tensor Cores
|
||||
|
||||
11_planar_complex_array/ # example demonstrating planar complex kernels with batch-specific problem sizes
|
||||
10_planar_complex/ # example demonstrating planar complex GEMM kernels
|
||||
|
||||
12_gemm_bias_relu/ # example demonstrating GEMM fused with bias and relu
|
||||
11_planar_complex_array/ # example demonstrating planar complex kernels with batch-specific problem sizes
|
||||
|
||||
13_fused_two_gemms/ # example demonstrating two GEMMs fused in one kernel
|
||||
12_gemm_bias_relu/ # example demonstrating GEMM fused with bias and relu
|
||||
|
||||
13_fused_two_gemms/ # example demonstrating two GEMMs fused in one kernel
|
||||
|
||||
22_ampere_tensorop_conv2dfprop/ # example demonstrating integer implicit GEMM convolution (forward propagation) using Ampere Tensor Cores
|
||||
|
||||
31_basic_syrk # example demonstrating Symmetric rank-K update
|
||||
|
||||
32_basic_trmm #
|
||||
|
||||
33_ampere_3xtf32_tensorop_symm #
|
||||
|
||||
35_gemm_softmax # example demonstrating GEMM fused with Softmax in mixed precision using Ampere Tensor Cores
|
||||
|
||||
40_cutlass_py # example demonstrating CUTLASS with CUDA Python
|
||||
```
|
||||
|
||||
### Tools
|
||||
|
||||
```
|
||||
tools/
|
||||
library/ # CUTLASS Instance Library - contains instantiations of all supported CUTLASS templates
|
||||
@ -257,20 +279,85 @@ Instructions for building and running the Unit tests are described in the [Quick
|
||||
The `tools/profiler/` directory contains a command-line utility for launching each of the GEMM kernels.
|
||||
It can be built as follows:
|
||||
|
||||
```bash
|
||||
$ make cutlass_profiler -j16
|
||||
```
|
||||
$ make cutlass_profiler -j
|
||||
```
|
||||
## Building all GEMM and Convolution kernels (_long_ build times)
|
||||
|
||||
To limit compilation time, only one tile size is instantiated for each data type, math instruction, and layout.
|
||||
By default, only one tile size is instantiated for each data type, math instruction, and layout.
|
||||
To instantiate all, set the following environment variable when running CMake from an empty `build/` directory.
|
||||
```
|
||||
Beware, this results in *thousands* of kernels and long build times.
|
||||
```bash
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS=75 -DCUTLASS_LIBRARY_KERNELS=all
|
||||
...
|
||||
$ make cutlass_profiler -j
|
||||
$ make cutlass_profiler -j16
|
||||
```
|
||||
|
||||
Example command line for profiling SGEMM kernels is as follows:
|
||||
## Building a subset of GEMM and Convolution kernels (_reduced_ build times)
|
||||
|
||||
To compile strictly one kernel or a small set of kernels, a comma-delimited list of kernel names with
wildcard characters may be used to reduce the set of kernels. The following examples show how to build exactly one
kernel or a subset of kernels for the NVIDIA Ampere and Turing architectures:
|
||||
|
||||
### Building a subset of Tensor Core GEMM kernels

To compile a subset of Tensor Core GEMM kernels with FP32 accumulation and FP16 input targeting the NVIDIA Ampere and Turing architectures,
use the following CMake command line:
|
||||
```bash
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS='75;80' -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_s*gemm_f16_*_nt_align8
|
||||
...
|
||||
$ make cutlass_profiler -j16
|
||||
```
|
||||
|
||||
Example command line for profiling a subset of Tensor Core GEMM kernels is as follows:
|
||||
```bash
|
||||
./tools/profiler/cutlass_profiler --kernels=cutlass_tensorop_s*gemm_f16_*_nt_align8 --m=3456 --n=4096 --k=4096
|
||||
|
||||
...
|
||||
=============================
|
||||
Problem ID: 1
|
||||
|
||||
Provider: CUTLASS
|
||||
OperationKind: gemm
|
||||
Operation: cutlass_tensorop_s1688gemm_f16_256x128_32x2_nt_align8
|
||||
|
||||
Status: Success
|
||||
Verification: ON
|
||||
Disposition: Passed
|
||||
|
||||
reference_device: Passed
|
||||
cuBLAS: Passed
|
||||
|
||||
Arguments: --gemm_kind=universal --m=3456 --n=4096 --k=4096 --A=f16:column --B=f16:row --C=f32:column --alpha=1 \
|
||||
--beta=0 --split_k_slices=1 --batch_count=1 --op_class=tensorop --accum=f32 --cta_m=256 --cta_n=128 \
|
||||
--cta_k=32 --stages=2 --warps_m=4 --warps_n=2 --warps_k=1 --inst_m=16 --inst_n=8 --inst_k=8 --min_cc=75 \
|
||||
--max_cc=1024
|
||||
|
||||
Bytes: 118489088 bytes
|
||||
FLOPs: 115992428544 flops
|
||||
|
||||
Runtime: 1.55948 ms
|
||||
Memory: 70.7616 GiB/s
|
||||
|
||||
Math: 74378.8 GFLOP/s
|
||||
|
||||
|
||||
|
||||
=============================
|
||||
...
|
||||
```
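The `Bytes`, `FLOPs`, and throughput figures above can be reproduced by hand. A small sketch, under the assumption (inferred from the reported totals, not taken from the profiler source) that the profiler counts 2·M·N·K FLOPs for the mainloop plus 2·M·N for the epilogue:

```c++
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t M = 3456, N = 4096, K = 4096;
  const double runtime_ms = 1.55948;        // "Runtime" reported above
  const int64_t bytes     = 118489088;      // "Bytes" reported above

  // 2*M*N*K multiply-accumulate FLOPs plus a 2*M*N epilogue (alpha/beta update).
  const int64_t flops = 2 * M * N * K + 2 * M * N;   // 115,992,428,544

  const double gflops = double(flops) / (runtime_ms * 1e-3) / 1e9;                        // ~74,379 GFLOP/s
  const double gibps  = double(bytes) / (runtime_ms * 1e-3) / (1024.0 * 1024.0 * 1024.0); // ~70.76 GiB/s

  std::printf("FLOPs = %lld, Math = %.1f GFLOP/s, Memory = %.2f GiB/s\n",
              static_cast<long long>(flops), gflops, gibps);
  return 0;
}
```

Both computed figures match the profiler's `Math` and `Memory` lines above to the printed precision.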
|
||||
|
||||
### Building one CUDA Core GEMM kernel
|
||||
|
||||
To compile one SGEMM kernel targeting the NVIDIA Ampere and Turing architectures, use the following CMake command line:
|
||||
```bash
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS='75;80' -DCUTLASS_LIBRARY_KERNELS=cutlass_simt_sgemm_128x128_8x2_nn_align1
|
||||
...
|
||||
$ make cutlass_profiler -j16
|
||||
```
|
||||
|
||||
Example command line for profiling a single SGEMM CUDA kernel is as follows:
|
||||
```bash
|
||||
$ ./tools/profiler/cutlass_profiler --kernels=sgemm --m=3456 --n=4096 --k=4096
|
||||
|
||||
=============================
|
||||
@ -297,9 +384,111 @@ $ ./tools/profiler/cutlass_profiler --kernels=sgemm --m=3456 --n=4096 --k=4096
|
||||
Memory: 24.934 GiB/s
|
||||
|
||||
Math: 17218.4 GFLOP/s
|
||||
|
||||
=============================
|
||||
```
|
||||
|
||||
[Further details about the CUTLASS Profiler are described here.](media/docs/profiler.md)
|
||||
### Building a subset of Tensor Core Convolution kernels
|
||||
|
||||
To compile a subset of Tensor Core convolution kernels implementing forward propagation (fprop) with FP32 accumulation
and FP16 input targeting the NVIDIA Ampere and Turing architectures, use the following CMake command line:
|
||||
```bash
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS='75;80' -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_s*fprop_optimized_f16
|
||||
...
|
||||
$ make cutlass_profiler -j16
|
||||
```
|
||||
|
||||
Example command line for profiling a subset of Tensor Core convolution kernels is as follows:
|
||||
|
||||
```bash
|
||||
$ ./tools/profiler/cutlass_profiler --kernels=cutlass_tensorop_s*fprop_optimized_f16 --n=8 --h=224 --w=224 --c=128 --k=128 --r=3 --s=3
|
||||
|
||||
...
|
||||
=============================
|
||||
Problem ID: 1
|
||||
|
||||
Provider: CUTLASS
|
||||
OperationKind: conv2d
|
||||
Operation: cutlass_tensorop_s16816fprop_optimized_f16_128x128_32x5_nhwc
|
||||
|
||||
Status: Success
|
||||
Verification: ON
|
||||
Disposition: Passed
|
||||
|
||||
reference_device: Passed
|
||||
|
||||
Arguments: --conv_kind=fprop --n=8 --h=224 --w=224 --c=128 --k=128 --r=3 --s=3 --p=224 --q=224 --pad_h=1 --pad_w=1 \
|
||||
--stride_h=1 --stride_w=1 --dilation_h=1 --dilation_w=1 --Activation=f16:nhwc --Filter=f16:nhwc --Output=f32:nhwc \
|
||||
--conv_mode=cross --iterator_algorithm=optimized --alpha=1 --beta=0 --split_k_mode=serial --split_k_slices=1 \
|
||||
--eq_gemm_provider=none --op_class=tensorop --accum=f32 --cta_m=128 --cta_n=128 --cta_k=32 --stages=5 \
|
||||
--warps_m=2 --warps_n=2 --warps_k=1 --inst_m=16 --inst_n=8 --inst_k=16 --min_cc=80 --max_cc=1024
|
||||
|
||||
Bytes: 1130659840 bytes
|
||||
FLOPs: 118482796544 flops
|
||||
|
||||
Runtime: 0.711496 ms
|
||||
Memory: 1479.99 GiB/s
|
||||
|
||||
Math: 166526 GFLOP/s
|
||||
|
||||
=============================
|
||||
...
|
||||
```
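The same sanity check applies to the convolution report above, using the implicit GEMM extents sketched earlier: assuming the profiler counts 2·(N·P·Q)·K·(C·R·S) mainloop FLOPs plus a 2·(N·P·Q)·K epilogue term (again inferred from the printed totals):

```c++
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t N = 8, P = 224, Q = 224, K = 128, C = 128, R = 3, S = 3;
  const int64_t npqk  = N * P * Q * K;                       // 51,380,224
  const int64_t flops = 2 * npqk * (C * R * S) + 2 * npqk;   // 118,482,796,544
  std::printf("conv2d fprop FLOPs = %lld\n", static_cast<long long>(flops));
  return 0;
}
```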
|
||||
|
||||
|
||||
### Building one Convolution CUDA kernel
|
||||
|
||||
To compile and run one CUDA Core convolution kernel implementing forward propagation (fprop) with FP32 accumulation
and FP32 input targeting the NVIDIA Ampere and Turing architectures, use the following CMake command line:
|
||||
```bash
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS='75;80' -DCUTLASS_LIBRARY_KERNELS=cutlass_simt_sfprop_optimized_128x128_8x2_nhwc
|
||||
...
|
||||
$ make cutlass_profiler -j16
|
||||
```
|
||||
|
||||
Example command line for profiling one CUDA Core convolution kernel:
|
||||
|
||||
```bash
|
||||
$ ./tools/profiler/cutlass_profiler --kernels=cutlass_simt_sfprop_optimized_128x128_8x2_nhwc --n=8 --h=224 --w=224 --c=128 --k=128 --r=3 --s=3
|
||||
|
||||
|
||||
=============================
|
||||
Problem ID: 1
|
||||
|
||||
Provider: CUTLASS
|
||||
OperationKind: conv2d
|
||||
Operation: cutlass_simt_sfprop_optimized_128x128_8x2_nhwc
|
||||
|
||||
Status: Success
|
||||
Verification: ON
|
||||
Disposition: Passed
|
||||
|
||||
reference_device: Passed
|
||||
|
||||
Arguments: --conv_kind=fprop --n=8 --h=224 --w=224 --c=128 --k=128 --r=3 --s=3 --p=224 --q=224 --pad_h=1 --pad_w=1 \
|
||||
--stride_h=1 --stride_w=1 --dilation_h=1 --dilation_w=1 --Activation=f32:nhwc --Filter=f32:nhwc --Output=f32:nhwc \
|
||||
--conv_mode=cross --iterator_algorithm=optimized --alpha=1 --beta=0 --split_k_mode=serial --split_k_slices=1 \
|
||||
--eq_gemm_provider=none --op_class=simt --accum=f32 --cta_m=128 --cta_n=128 --cta_k=8 --stages=2 --warps_m=4 \
|
||||
--warps_n=2 --warps_k=1 --inst_m=1 --inst_n=1 --inst_k=1 --min_cc=50 --max_cc=1024
|
||||
|
||||
Bytes: 2055798784 bytes
|
||||
FLOPs: 118482796544 flops
|
||||
|
||||
Runtime: 7.34266 ms
|
||||
Memory: 260.752 GiB/s
|
||||
|
||||
Math: 16136.2 GFLOP/s
|
||||
|
||||
|
||||
=============================
|
||||
|
||||
```
|
||||
|
||||
## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler
|
||||
- Please follow the links for more CMake examples on selectively compiling CUTLASS kernels:
|
||||
- [GEMM CMake Examples](media/docs/quickstart.md#gemm-cmake-examples)
|
||||
- [Implicit GEMM convolution CMake Examples](media/docs/quickstart.md#convolution-cmake-examples)
|
||||
- [Further details about the CUTLASS Profiler are described here.](media/docs/profiler.md)
|
||||
|
||||
|
||||
# About
|
||||
@ -313,27 +502,33 @@ The official list of CUTLASS developers and contributors is available here: [CON
|
||||
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
provided that the following conditions are met:
|
||||
* Redistributions of source code must retain the above copyright notice, this list of
|
||||
conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
conditions and the following disclaimer in the documentation and/or other materials
|
||||
provided with the distribution.
|
||||
* Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
to endorse or promote products derived from this software without specific prior written
|
||||
permission.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
|
||||
21 cmake/CTestTestfile.config.cmake Normal file
@@ -0,0 +1,21 @@
# Generated file

if (DEFINED ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT})
  set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT $ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT})
else()
  set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT @CUTLASS_TEST_EXECUTION_ENVIRONMENT@)
endif()

if (NOT "@TEST_EXE_DIR@" STREQUAL "")
  set(TEST_EXE_PATH @TEST_EXE_DIR@/@TEST_EXE@)
else()
  set(TEST_EXE_PATH @TEST_EXE@)
endif()

add_test("@TEST_NAME@" ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@)

if (NOT "@TEST_EXE_WORKING_DIRECTORY@" STREQUAL "")
  set_tests_properties("@TEST_NAME@" PROPERTIES WORKING_DIRECTORY "@TEST_EXE_WORKING_DIRECTORY@")
endif()

set_tests_properties(@TEST_NAME@ PROPERTIES DISABLED @__DISABLE_TESTS@)

42 cmake/nop.cu
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
27 cuBLAS.cmake
@ -1,3 +1,30 @@
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
message(STATUS "Configuring cublas ...")
|
||||
|
||||
|
||||
112 cuDNN.cmake Normal file
@ -0,0 +1,112 @@
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
if(DEFINED CUDNN_ENABLED)
|
||||
set(CUTLASS_ENABLE_CUDNN ${CUDNN_ENABLED} CACHE BOOL "Enable CUTLASS to build with cuDNN library.")
|
||||
endif()
|
||||
|
||||
if(DEFINED CUTLASS_ENABLE_CUDNN AND NOT CUTLASS_ENABLE_CUDNN)
|
||||
return()
|
||||
endif()
|
||||
|
||||
message(STATUS "Configuring cuDNN ...")
|
||||
|
||||
find_path(
|
||||
_CUDNN_INCLUDE_DIR cudnn.h
|
||||
PATHS
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/include
|
||||
$ENV{CUDNN_PATH}/include
|
||||
$ENV{CUDA_PATH}/include
|
||||
${CUDNN_PATH}/include
|
||||
/usr/include)
|
||||
|
||||
find_library(
|
||||
_CUDNN_LIBRARY cudnn
|
||||
HINTS
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/lib64
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/lib
|
||||
$ENV{CUDNN_PATH}/lib64
|
||||
$ENV{CUDNN_PATH}/lib/x64
|
||||
$ENV{CUDNN_PATH}/lib
|
||||
$ENV{CUDA_PATH}/lib64
|
||||
$ENV{CUDA_PATH}/lib/x64
|
||||
$ENV{CUDA_PATH}/lib
|
||||
${CUDNN_PATH}/lib64
|
||||
${CUDNN_PATH}/lib/x64
|
||||
${CUDNN_PATH}/lib
|
||||
/usr/lib/x86_64-linux-gnu
|
||||
/usr/lib)
|
||||
|
||||
if(_CUDNN_INCLUDE_DIR AND _CUDNN_LIBRARY)
|
||||
|
||||
message(STATUS "cuDNN: ${_CUDNN_LIBRARY}")
|
||||
message(STATUS "cuDNN: ${_CUDNN_INCLUDE_DIR}")
|
||||
|
||||
set(CUDNN_FOUND ON CACHE INTERNAL "cuDNN Library Found")
|
||||
|
||||
else()
|
||||
|
||||
message(STATUS "cuDNN not found.")
|
||||
set(CUDNN_FOUND OFF CACHE INTERNAL "cuDNN Library Found")
|
||||
|
||||
endif()
|
||||
|
||||
set(CUTLASS_ENABLE_CUDNN ${CUDNN_FOUND} CACHE BOOL "Enable CUTLASS to build with cuDNN library.")
|
||||
|
||||
if (CUTLASS_ENABLE_CUDNN AND NOT TARGET cudnn)
|
||||
|
||||
set(CUDNN_INCLUDE_DIR ${_CUDNN_INCLUDE_DIR})
|
||||
set(CUDNN_LIBRARY ${_CUDNN_LIBRARY})
|
||||
|
||||
if(WIN32)
|
||||
add_library(cudnn STATIC IMPORTED GLOBAL)
|
||||
else()
|
||||
add_library(cudnn SHARED IMPORTED GLOBAL)
|
||||
endif()
|
||||
|
||||
add_library(nvidia::cudnn ALIAS cudnn)
|
||||
|
||||
set_property(
|
||||
TARGET cudnn
|
||||
PROPERTY IMPORTED_LOCATION
|
||||
${CUDNN_LIBRARY})
|
||||
|
||||
target_include_directories(
|
||||
cudnn
|
||||
INTERFACE
|
||||
$<INSTALL_INTERFACE:include>
|
||||
$<BUILD_INTERFACE:${CUDNN_INCLUDE_DIR}>)
|
||||
|
||||
endif()
|
||||
|
||||
if(CUTLASS_ENABLE_CUDNN AND NOT CUDNN_FOUND)
|
||||
message(FATAL_ERROR "CUTLASS_ENABLE_CUDNN enabled but cuDNN library could not be found.")
|
||||
endif()
|
||||
|
||||
message(STATUS "Configuring cuDNN ... done.")
|
||||
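When the configuration above reports that cuDNN was found, one quick way to confirm which cuDNN the resulting binaries actually load is to query it at runtime. A minimal sketch; it assumes only the standard `cudnn.h` header, the `CUDNN_VERSION` macro, and the `cudnnGetVersion()` entry point:

```c++
#include <cstdio>
#include <cudnn.h>

int main() {
  // Version the translation unit was compiled against vs. the library loaded at runtime.
  std::printf("cudnn.h version  : %d\n", CUDNN_VERSION);
  std::printf("libcudnn version : %zu\n", cudnnGetVersion());
  return 0;
}
```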
1 docs/_config.yml Normal file
@ -0,0 +1 @@
|
||||
theme: jekyll-theme-minimal
|
||||
@ -1,25 +1,33 @@
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
00_basic_gemm
|
||||
basic_gemm.cu
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -148,7 +154,6 @@ cudaError_t CutlassSgemmNN(
|
||||
/// Kernel to initialize a matrix with small integers.
|
||||
__global__ void InitializeMatrix_kernel(
|
||||
float *matrix,
|
||||
int ldm,
|
||||
int rows,
|
||||
int columns,
|
||||
int seed = 0) {
|
||||
@ -157,7 +162,7 @@ __global__ void InitializeMatrix_kernel(
|
||||
int j = threadIdx.y + blockIdx.y * blockDim.y;
|
||||
|
||||
if (i < rows && j < columns) {
|
||||
int offset = i + j * ldm;
|
||||
int offset = i + j * rows;
|
||||
|
||||
// Generate arbitrary elements.
|
||||
int const k = 16807;
|
||||
@ -169,7 +174,7 @@ __global__ void InitializeMatrix_kernel(
|
||||
}
|
||||
|
||||
/// Simple function to initialize a matrix to arbitrary small integers.
|
||||
cudaError_t InitializeMatrix(float *matrix, int ldm, int rows, int columns, int seed = 0) {
|
||||
cudaError_t InitializeMatrix(float *matrix, int rows, int columns, int seed = 0) {
|
||||
|
||||
dim3 block(16, 16);
|
||||
dim3 grid(
|
||||
@ -177,7 +182,7 @@ cudaError_t InitializeMatrix(float *matrix, int ldm, int rows, int columns, int
|
||||
(columns + block.y - 1) / block.y
|
||||
);
|
||||
|
||||
InitializeMatrix_kernel<<< grid, block >>>(matrix, ldm, rows, columns, seed);
|
||||
InitializeMatrix_kernel<<< grid, block >>>(matrix, rows, columns, seed);
|
||||
|
||||
return cudaGetLastError();
|
||||
}
|
||||
@ -185,10 +190,10 @@ cudaError_t InitializeMatrix(float *matrix, int ldm, int rows, int columns, int
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Allocates device memory for a matrix then fills with arbitrary small integers.
|
||||
cudaError_t AllocateMatrix(float **matrix, int ldm, int rows, int columns, int seed = 0) {
|
||||
cudaError_t AllocateMatrix(float **matrix, int rows, int columns, int seed = 0) {
|
||||
cudaError_t result;
|
||||
|
||||
size_t sizeof_matrix = sizeof(float) * ldm * columns;
|
||||
size_t sizeof_matrix = sizeof(float) * rows * columns;
|
||||
|
||||
// Allocate device memory.
|
||||
result = cudaMalloc(reinterpret_cast<void **>(matrix), sizeof_matrix);
|
||||
@ -209,7 +214,7 @@ cudaError_t AllocateMatrix(float **matrix, int ldm, int rows, int columns, int s
|
||||
}
|
||||
|
||||
// Initialize matrix elements to arbitrary small integers.
|
||||
result = InitializeMatrix(*matrix, ldm, rows, columns, seed);
|
||||
result = InitializeMatrix(*matrix, rows, columns, seed);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Failed to initialize matrix: "
|
||||
@ -304,20 +309,20 @@ cudaError_t TestCutlassGemm(int M, int N, int K, float alpha, float beta) {
|
||||
// Allocate matrices in GPU device memory with arbitrary seeds.
|
||||
//
|
||||
|
||||
result = AllocateMatrix(&A, lda, M, K, 0);
|
||||
result = AllocateMatrix(&A, M, K, 0);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return result;
|
||||
}
|
||||
|
||||
result = AllocateMatrix(&B, ldb, K, N, 17);
|
||||
result = AllocateMatrix(&B, K, N, 17);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
cudaFree(A);
|
||||
return result;
|
||||
}
|
||||
|
||||
result = AllocateMatrix(&C_cutlass, ldc, M, N, 101);
|
||||
result = AllocateMatrix(&C_cutlass, M, N, 101);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
cudaFree(A);
|
||||
@ -325,7 +330,7 @@ cudaError_t TestCutlassGemm(int M, int N, int K, float alpha, float beta) {
|
||||
return result;
|
||||
}
|
||||
|
||||
result = AllocateMatrix(&C_reference, ldc, M, N, 101);
|
||||
result = AllocateMatrix(&C_reference, M, N, 101);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
cudaFree(A);
|
||||
|
||||
@ -1,25 +1,33 @@
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
01_cutlass_utilities
|
||||
cutlass_utilities.cu
|
||||
|
||||
@ -1,24 +1,30 @@ [License header diff: the 2017-2020 NVIDIA CORPORATION notice is replaced by the 2017 - 2022 NVIDIA CORPORATION & AFFILIATES BSD-3-Clause header (SPDX-License-Identifier: BSD-3-Clause).]
@ -119,12 +125,12 @@ cudaError_t cutlass_hgemm_nn(
int K,
cutlass::half_t alpha,
cutlass::half_t const *A,
int lda,
cutlass::layout::ColumnMajor::Stride::Index lda,
cutlass::half_t const *B,
int ldb,
cutlass::layout::ColumnMajor::Stride::Index ldb,
cutlass::half_t beta,
cutlass::half_t *C,
int ldc) {
cutlass::layout::ColumnMajor::Stride::Index ldc) {

// Define the GEMM operation
using Gemm = cutlass::gemm::device::Gemm<
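For context, a minimal sketch of how a column-major FP16 GEMM of this shape is typically defined and invoked with the CUTLASS 2.x device-level API. This is illustrative only and not the code in the diff; the function name hgemm_nn_sketch and its body are assumptions, with the leading-dimension type taken from the hunk above.

#include "cutlass/gemm/device/gemm.h"

// Hedged sketch: column-major FP16 GEMM mirroring the signature shown above.
cudaError_t hgemm_nn_sketch(
    int M, int N, int K,
    cutlass::half_t alpha,
    cutlass::half_t const *A, cutlass::layout::ColumnMajor::Stride::Index lda,
    cutlass::half_t const *B, cutlass::layout::ColumnMajor::Stride::Index ldb,
    cutlass::half_t beta,
    cutlass::half_t *C, cutlass::layout::ColumnMajor::Stride::Index ldc) {

  using Gemm = cutlass::gemm::device::Gemm<
    cutlass::half_t, cutlass::layout::ColumnMajor,
    cutlass::half_t, cutlass::layout::ColumnMajor,
    cutlass::half_t, cutlass::layout::ColumnMajor
  >;

  Gemm gemm_op;

  cutlass::Status status = gemm_op({
    {M, N, K},
    {A, lda},      // tensor ref to A
    {B, ldb},      // tensor ref to B
    {C, ldc},      // tensor ref to C (source operand for beta scaling)
    {C, ldc},      // tensor ref to D (output)
    {alpha, beta}
  });

  return status == cutlass::Status::kSuccess ? cudaSuccess : cudaErrorUnknown;
}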
@ -1,25 +1,33 @@ [Same license header update as above.]
cutlass_example_add_executable(
02_dump_reg_shmem
dump_reg_shmem.cu
@ -1,27 +1,31 @@ [Same license header update as above.]
@ -69,7 +73,7 @@
template <typename Element, typename GmemIterator, typename SmemIterator>
__global__ void kernel_dump(typename GmemIterator::Params params,
typename GmemIterator::TensorRef ref) {
__shared__ Element shared_storage[EXAMPLE_MATRIX_ROW * EXAMPLE_MATRIX_COL];
extern __shared__ Element shared_storage[];

// Construct the global iterator and load the data to the fragments.
int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x;
@ -164,8 +168,11 @@ int main() {
dim3 grid(1, 1);
dim3 block(32, 1, 1);

int smem_size =
int(sizeof(Element) * EXAMPLE_MATRIX_ROW * EXAMPLE_MATRIX_COL);

kernel_dump<Element, GmemIterator, SmemIterator>
<<<grid, block>>>(params, matrix.device_ref());
<<<grid, block, smem_size, 0>>>(params, matrix.device_ref());

cudaError_t result = cudaDeviceSynchronize();
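The hunk above replaces a statically sized __shared__ array with extern __shared__ storage sized at launch time. A self-contained sketch of that general CUDA pattern follows; it is illustrative only, and the kernel name, element type, and extent are assumptions rather than the example's actual code.

#include <cuda_runtime.h>

// Dynamic shared memory: declare an unsized extern array in the kernel and pass the
// byte count as the third <<<...>>> launch parameter.
__global__ void copy_to_shared_sketch(float const *src, int count) {
  extern __shared__ float shared_storage[];  // sized by the launch configuration
  int tid = threadIdx.y * blockDim.x + threadIdx.x;
  for (int i = tid; i < count; i += blockDim.x * blockDim.y) {
    shared_storage[i] = src[i];
  }
  __syncthreads();
  // ... operate on shared_storage ...
}

int main() {
  int const count = 64 * 32;                 // assumed extent, analogous to ROW * COL above
  float *src = nullptr;
  cudaMalloc(&src, count * sizeof(float));
  int smem_size = int(sizeof(float) * count);
  copy_to_shared_sketch<<<dim3(1, 1), dim3(32, 1, 1), smem_size, 0>>>(src, count);
  cudaDeviceSynchronize();
  cudaFree(src);
  return 0;
}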
@ -1,28 +1,42 @@ [Same license header update as above.]
set(TEST_COMMAND_00 RowMajor --extent=16,16)
set(TEST_COMMAND_01 \"ColumnMajorInterleaved<4>\" --extent=32,8 --output-shape=16 --vectorize=4)

cutlass_example_add_executable(
03_visualize_layout
visualize_layout.cpp
register_layout.cu
TEST_COMMAND_OPTIONS
TEST_COMMAND_00
TEST_COMMAND_01
)
@ -1,24 +1,30 @@ [Same license header update as above.]
@ -1,24 +1,30 @@ [Same license header update as above.]
@ -55,27 +61,49 @@ void RegisterLayouts(std::map<std::string, std::unique_ptr<VisualizeLayoutBase>
new VisualizeLayout<cutlass::layout::ColumnMajorInterleaved<4>>},
{"RowMajorInterleaved<4>",
new VisualizeLayout<cutlass::layout::RowMajorInterleaved<4>>},
// All Ampere/Turing H/Integer matrix multiply tensor core kernels uses the same swizzling
// layout implementation with different templates.
//
// BMMA 88128 Interleaved-256
// BMMA 168256 Interleaved-256
{"TensorOpMultiplicand<1,256>",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<1, 256>>},
// BMMA 88128 TN kblock512
// BMMA 168256 TN kblock512
{"TensorOpMultiplicand<1,512>",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<1, 512>>},
// BMMA 168256 TN kblock1024
{"TensorOpMultiplicand<1,1024>",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<1, 1024>>},
// Integer matrix multiply.int4 8832 Interleaved-64
// Integer matrix multiply.int4 16864 Interleaved-64
{"TensorOpMultiplicand<4,64>",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<4, 64>>},
// Integer matrix multiply.int4 8832 TN kblock128
// Integer matrix multiply.int4 16864 TN kblock128
{"TensorOpMultiplicand<4,128>",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<4, 128>>},
// Integer matrix multiply.int4 16864 TN kblock256
{"TensorOpMultiplicand<4,256>",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<4, 256>>},
// Integer matrix multiply 8816 Interleaved-32
// Integer matrix multiply 16832 Interleaved-32
{"TensorOpMultiplicand<8,32>",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<8, 32>>},
// Integer matrix multiply 8816 TN kblock64
// Integer matrix multiply 16832 TN kblock64
{"TensorOpMultiplicand<8,64>",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<8, 64>>},
// Integer matrix multiply 16832 TN kblock128
{"TensorOpMultiplicand<8,128>",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<8, 128>>},
// Matrix Multiply 1688 TN kblock32
// Matrix multiply 16816 TN kblock32
{"TensorOpMultiplicand<16,32>",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<16, 32>>},
// Matrix multiply 1688 NT
// Matrix multiply 16816 NT
// Matrix multiply 16816 TN kblock64
{"TensorOpMultiplicand<16,64>",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<16, 64>>},
// Matrix multiply 1688.TF32 TN kblock16
@ -1,24 +1,30 @@ [Same license header update as above.]
@ -1,24 +1,30 @@ [Same license header update as above.]
@ -32,6 +38,8 @@
#include <iomanip>
#include <memory>

#include <cutlass/cutlass.h>

#include "options.h"
#include "register_layout.h"
@ -133,6 +141,8 @@ int main(int argc, char const *arg[]) {

layout_it->second->print_csv(std::cout);

cudaFree(0); // Ensure CUDA is available.

return 0;
}
@ -1,24 +1,30 @@ [Same license header update as above.]
@ -1,25 +1,33 @@ [Same license header update as above.]
cutlass_example_add_executable(
04_tile_iterator
tile_iterator.cu
@ -1,24 +1,30 @@ [Same license header update as above.]
@ -1,25 +1,33 @@ [Same license header update as above.]
cutlass_example_add_executable(
05_batched_gemm
batched_gemm.cu
@ -1,24 +1,30 @@ [Same license header update as above.]
@ -28,12 +34,16 @@

#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/device/gemm_array.h"
#include "cutlass/gemm/device/gemm_batched.h"

#pragma warning( disable : 4503)

/*
This example demonstrates how to use cutlass to compute a batched strided gemm.
This example demonstrates how to use cutlass to compute a batched strided gemm in two different ways:
1. By specifying pointers to the first matrices of the batch and the stride between the consecutive
matrices of the batch (this is called a strided batched gemm).
2. By copying pointers to all matrices of the batch to the device memory (this is called an array gemm).
In this example, both A and B matrix are non-transpose and column major matrix
batched_C = batched_A x batched_B
As an example, matrix C can be seen as
@ -89,6 +99,45 @@ The stride (batch_stride_C) between the first element of two batches is k

*/

cudaError_t cutlass_array_sgemm(
int m,
int n,
int k,
float alpha,
float const * const *A,
int lda,
float const * const *B,
int ldb,
float * const *C,
int ldc,
float beta,
int batch_count) {

using Gemm = cutlass::gemm::device::GemmArray<
float, cutlass::layout::ColumnMajor,
float, cutlass::layout::ColumnMajor,
float, cutlass::layout::ColumnMajor
>;

Gemm gemm_op;

cutlass::Status status = gemm_op({
{m, n, k},
A, lda,
B, ldb,
C, ldc,
C, ldc,
{alpha, beta},
batch_count
});

if (status != cutlass::Status::kSuccess) {
return cudaErrorUnknown;
}

return cudaSuccess;
}

cudaError_t cutlass_strided_batched_sgemm(
int m,
int n,
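The hunk above shows the array path (GemmArray) added by the diff. For comparison, here is a sketch of the strided-batched path described in point 1 of the comment, built on cutlass::gemm::device::GemmBatched. It is an illustration under the example's assumptions (column-major float matrices); the function name and exact body are assumptions, not the diff's cutlass_strided_batched_sgemm verbatim.

#include "cutlass/gemm/device/gemm_batched.h"

// Hedged sketch: every batch element is offset from the previous one by a fixed stride.
cudaError_t strided_batched_sgemm_sketch(
    int m, int n, int k, float alpha,
    float const *A, int lda, long long batch_stride_A,
    float const *B, int ldb, long long batch_stride_B,
    float *C, int ldc, long long batch_stride_C,
    float beta, int batch_count) {

  using Gemm = cutlass::gemm::device::GemmBatched<
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor
  >;

  Gemm gemm_op;

  cutlass::Status status = gemm_op({
    {m, n, k},
    {A, lda}, batch_stride_A,
    {B, ldb}, batch_stride_B,
    {C, ldc}, batch_stride_C,
    {C, ldc}, batch_stride_C,
    {alpha, beta},
    batch_count
  });

  return status == cutlass::Status::kSuccess ? cudaSuccess : cudaErrorUnknown;
}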
@ -188,7 +237,11 @@ cudaError_t strided_batched_gemm_nn_reference(
return result;
}

int main() {

cudaError_t run_batched_gemm(bool use_array) {

const char* gemm_desc = use_array ? "array" : "strided batched";
std::cout << "Running " << gemm_desc << " gemm" << std::endl;

// Arbitrary problem size
int const m = 520;
@ -293,11 +346,69 @@ int main() {
}

// run cutlass
result = cutlass_strided_batched_sgemm(
m, n, k, alpha, A, lda, batch_stride_A, B, ldb, batch_stride_B, C, ldc, batch_stride_C,
beta, batch_count);
if (result != cudaSuccess)
return result;
if (use_array) {
// allocate the host memory for the pointers to the matrices of the batch
std::vector<float*> host_ptr_A(batch_count);
std::vector<float*> host_ptr_B(batch_count);
std::vector<float*> host_ptr_C(batch_count);

// permute the batch elements to emphasize that GemmArray does not depend on matrices being separated by a fixed stride
std::vector<size_t> permutation = {14, 11, 3, 10, 1, 13, 9, 4, 6, 16, 8, 15, 7, 12, 0, 2, 5};
for (size_t b_idx = 0; b_idx < batch_count; b_idx++) {
host_ptr_A[b_idx] = A + permutation[b_idx] * batch_stride_A;
host_ptr_B[b_idx] = B + permutation[b_idx] * batch_stride_B;
host_ptr_C[b_idx] = C + permutation[b_idx] * batch_stride_C;
}

// allocate the corresponding device memory
float const **ptr_A;
float const **ptr_B;
float **ptr_C;

result = cudaMalloc(&ptr_A, batch_count * sizeof(float*));
if (result != cudaSuccess) {
std::cerr << "cudaMalloc result = " << result << std::endl;
return result;
}
result = cudaMalloc(&ptr_B, batch_count * sizeof(float*));
if (result != cudaSuccess) {
std::cerr << "cudaMalloc result = " << result << std::endl;
return result;
}
result = cudaMalloc(&ptr_C, batch_count * sizeof(float*));
if (result != cudaSuccess) {
std::cerr << "cudaMalloc result = " << result << std::endl;
return result;
}

// copy the matrix pointers to the device
result = cudaMemcpy(ptr_A, host_ptr_A.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
if (result != cudaSuccess) {
std::cerr << "cudaMemcpy result = " << result << std::endl;
return result;
}
result = cudaMemcpy(ptr_B, host_ptr_B.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
if (result != cudaSuccess) {
std::cerr << "cudaMemcpy result = " << result << std::endl;
return result;
}
result = cudaMemcpy(ptr_C, host_ptr_C.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
if (result != cudaSuccess) {
std::cerr << "cudaMemcpy result = " << result << std::endl;
return result;
}

result = cutlass_array_sgemm(m, n, k, alpha, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, beta, batch_count);

if (result != cudaSuccess)
return result;
} else {
result = cutlass_strided_batched_sgemm(
m, n, k, alpha, A, lda, batch_stride_A, B, ldb, batch_stride_B, C, ldc, batch_stride_C,
beta, batch_count);
if (result != cudaSuccess)
return result;
}

// copy device memory to host
result = cudaMemcpy(result_C.data(), C, count_C * sizeof(float), cudaMemcpyDeviceToHost);
@ -314,7 +425,7 @@ int main() {

// Expect bit-level accuracy for this simple example
if (ref_C != result_C) {
std::cout << "CUTLASS strided batched gemm does not run correctly" << std::endl;
std::cout << "CUTLASS " << gemm_desc << " gemm does not run correctly" << std::endl;
return cudaErrorUnknown;
}
@ -335,9 +446,19 @@ int main() {
return result;
}

return result;
}

if (result == cudaSuccess) {
std::cout << "Passed." << std::endl;
int main() {

cudaError_t result = cudaSuccess;
for (bool use_array : {false, true}) {
result = run_batched_gemm(use_array);
if (result == cudaSuccess) {
std::cout << "Passed." << std::endl;
} else {
break;
}
}

// Exit.
@ -1,25 +1,33 @@ [Same license header update as above.]
cutlass_example_add_executable(
|
||||
06_splitK_gemm
|
||||
splitk_gemm.cu
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -205,7 +211,7 @@ int run() {
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
|
||||
problem_size.mk()); // <- Create matrix A with dimensions M x K
|
||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
|
||||
problem_size.nk()); // <- Create matrix B with dimensions N x K
|
||||
problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
|
||||
problem_size.mn()); // <- Create matrix C with dimensions M x N
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
|
||||
|
||||
@ -1,25 +1,33 @@
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
07_volta_tensorop_gemm
|
||||
volta_tensorop_gemm.cu
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -67,7 +73,7 @@ beta * C).
|
||||
Now that we setup the properties of data, we have to setup properties of computation.
|
||||
|
||||
Second, we create template variables of tile sizes for thread-block, warp and mma-op to 128x128x32,
|
||||
64x64x4, 8x8x4 (MxNxK) respectively. When passed to instantiate CUTLASS GEMM kernel, it internally
|
||||
64x64x32, 8x8x4 (MxNxK) respectively. When passed to instantiate CUTLASS GEMM kernel, it internally
|
||||
deduces the amount of threads needed per thread-block, the amount of shared memory, storing data in a
bank-conflict-free manner, and a ton of other variables required to compose, initialize and launch a
high performance GEMM kernel. This is the beauty of CUTLASS: it relieves the developer from
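Concretely, the three tile sizes described above are expressed as GemmShape template aliases when the
kernel type is composed; a short sketch (the alias names here are illustrative, following the example's
conventions):

using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 128, 32>;  // threadblock tile: M=128, N=128, K=32
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>;           // warp tile: M=64, N=64, K=32
using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 4>;                // Volta Tensor Core mma.sync shape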
@ -284,8 +290,12 @@ int run() {
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm_op;
|
||||
|
||||
// Check the problem size is supported or not
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
|
||||
status = gemm_op.initialize(arguments, workspace.get());
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// Launch initialized CUTLASS kernel
|
||||
|
||||
@ -1,25 +1,33 @@
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
08_turing_tensorop_gemm
|
||||
turing_tensorop_gemm.cu
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -188,31 +194,6 @@ using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
|
||||
|
||||
int run() {
|
||||
|
||||
// Turing Tensor Core operations exposed with mma.sync and ldmatrix are first available
|
||||
// in CUDA 10.2.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
|
||||
std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!((props.major * 10 + props.minor) >= 75)) {
|
||||
std::cerr << "Turing Tensor Core operations must be run on a machine with compute capability at least 75."
|
||||
<< std::endl;
|
||||
|
||||
// Return 0 so tests are considered passing if run on unsupported platforms.
|
||||
return 0;
|
||||
}
|
||||
|
||||
const int length_m = 5120;
|
||||
const int length_n = 4096;
|
||||
const int length_k = 4096;
|
||||
@ -291,8 +272,12 @@ int run() {
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm_op;
|
||||
|
||||
// Check the problem size is supported or not
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
|
||||
status = gemm_op.initialize(arguments, workspace.get());
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// Launch initialized CUTLASS kernel
|
||||
@ -337,18 +322,37 @@ int run() {
|
||||
}
|
||||
|
||||
int main() {
|
||||
bool notSupported = false;
|
||||
|
||||
// Turing Tensor Core operations exposed with mma.sync and ldmatrix are first available
|
||||
// in CUDA 10.2.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
|
||||
std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
// Returning zero so this test passes when built on older Toolkits.
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!((props.major * 10 + props.minor) >= 75)) {
|
||||
std::cerr << "Turing Tensor Core operations must be run on a machine with compute capability at least 75."
|
||||
<< std::endl;
|
||||
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
else {
|
||||
return run();
|
||||
}
|
||||
|
||||
return run();
|
||||
}
|
||||
|
||||
|
||||
36
examples/09_turing_tensorop_conv2dfprop/CMakeLists.txt
Normal file
@ -0,0 +1,36 @@
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
09_turing_tensorop_conv2dfprop
|
||||
turing_tensorop_conv2dfprop.cu
|
||||
)
|
||||
|
||||
@ -0,0 +1,770 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/**
|
||||
|
||||
|
||||
This example shows how to run convolution kernels using functions and data structures
|
||||
provided by CUTLASS using tensor cores; which we run on a NVIDIA Turing GPU.
|
||||
|
||||
Writing a single high performance convolution kernel is hard but doable, whereas writing
high performance kernels at scale that work for multiple problem sizes with good abstractions is
really hard. CUTLASS solves this problem by providing simplified abstractions to compose
multiple sections of an implicit GEMM kernel. When used properly, the kernels can hit peak performance
of the GPU easily.
|
||||
CUTLASS divides a kernel into hierarchical composable sections. Which means, at each thread, warp
|
||||
and thread-block level, they compute on their own tile-size with higher level of tile sizes being
|
||||
composed from lower level ones. Multiple thread-tiles (tile size each thread computes) can be used
|
||||
to form warp-tiles (tile size each warp computes) and multiple warp tiles can be used to compute
|
||||
threadblock-tile (tile size computed by a threadblock).
|
||||
|
||||
In this example, we split variable initialization into
1. Setting up data properties: describes how tensors are laid out in the memory and how the kernel
can view them (logical to physical mapping)
2. Setting up computation properties: describes how the above set tensors will be used to compute
the output of convolution.
|
||||
First, we setup the data types of the input tensor A, weights' tensor B and output tensor C along
|
||||
with alpha, beta as the equation for convolution is C = alpha * Conv(A, B) + beta * C. In CUTLASS,
|
||||
the kernels first compute Conv(A, B) and leave the rest of the computation to end of the kernel as
|
||||
alpha * X + beta * C is a simple element-wise operation on X (Conv(A, B)) and C. We call this the
epilogue of the kernel. Hence, we set up the data types for alpha and beta to be equal to
ElementComputeEpilogue = float. We want to use MMA instructions on Turing and they support 4-bit
|
||||
signed integer. But int4b_t is not fully supported by Nvidia software stack, so CUTLASS introduces
|
||||
cutlass::int4b_t. We use the data type for elements in input tensor A and B as cutlass::int4b_t. We
|
||||
convey this to CUTLASS kernel by initializing template variables ElementAccumulator (int32_t),
|
||||
ElementComputeEpilogue (float), ElementInputA (cutlass::int4b_t), ElementInputB (cutlass::int4b_t),
|
||||
ElementOutput (int32_t). Communicating just the data type is not enough. As the data is laid out
|
||||
linearly in memory, we have to convey the layout of tensors. We do that by initializing template
|
||||
variables LayoutInputA, LayoutInputB and LayoutOutput to TensorNHWC cutlass variable. Next, we setup
|
||||
rules to compute alpha * X + beta * C, which is called the epilogue of the kernel. We initialize the template
variable EpilogueOp, which takes the data type of output ElementOutput (int32_t), the number of
|
||||
elements per vector memory access (32), data type of accumulator (int32_t) and data type of
|
||||
computation of linear combination (alpha * X + beta * C).
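Concretely, these element type choices appear later in this file as template aliases, for example:

using ElementAccumulator = int32_t;                 // Data type of accumulator
using ElementComputeEpilogue = float;               // Data type of epilogue computation (alpha, beta)
using ElementInputA = cutlass::int4b_t;             // Data type of elements in input tensor
using ElementInputB = cutlass::int4b_t;             // Data type of elements in input tensor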
Now that we have set up the properties of the data, we have to set up the properties of the computation.
Second, we create template variables of tile sizes for thread-block, warp and mma-op to 128x128x128,
|
||||
64x64x128, 8x8x32 (MxNxK) respectively. When passed to instantiate CUTLASS Implicit GEMM kernel, it
|
||||
internally deduces the amount of threads needed per thread-block, amount of shared memory, storing
|
||||
data in a bank-conflict-free manner, and a ton of other variables required to compose, initialize and
launch a high performance Implicit GEMM kernel. This is the beauty of CUTLASS: it relieves the developer
from understanding and coding complicated hardware optimizations which can easily go wrong.
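For instance, those tile sizes appear later in this file as the following template aliases:

using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>;  // Threadblock tile shape (MxNxK)
using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;           // Warp tile shape (MxNxK)
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>;       // TensorCore instruction shape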
CUTLASS also supports multiple MMA pipelines in a threadblock. What are MMA pipelines? MMA pipelines
|
||||
constitute the whole process of loading input data from global memory to shared memory, loading data
|
||||
from shared memory to registers, doing matrix multiplication, store to global memory. The below flow
|
||||
sequence shows a typical mma pipeline.
|
||||
|
||||
tensor in global memory -> registers -> tile in shared memory -> registers -> mma -> registers ->
|
||||
output to global memory
|
||||
|
||||
The problem with single pipeline is, each stage is synchronous which means, each stage has to wait
|
||||
until the previous finished executing. There are stages in the pipeline which do not have fixed
|
||||
latency, for example, the loads from global memory and shared memory. Therefore, we can add one more
|
||||
pipeline with a phase shift in mma kernel to hide latency from global and shared memory loads.
|
||||
Finally, the pipeline in a kernel looks like
|
||||
|
||||
(1) tensor in global memory -> (2) registers -> (3) tile in shared memory -> (4) registers -> (5)
|
||||
mma -> (6) registers -> (7) output to global memory (1) <null> -> (2) <null> -> (3) tensor in global
|
||||
memory -> (4) registers -> (5) tile in shared memory -> (6) registers -> (7) mma -> (8) registers ->
|
||||
(9) output to global memory
|
||||
|
||||
This way, you can hide the second global memory load latency by doing computation on already loaded
|
||||
input data.
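In this example, the two-stage pipeline sketched above is selected with a single constant, mirroring the
declaration further down in this file:

constexpr int NumStages = 2;   // Number of pipeline stages (the global->shared->register staging described above)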
There are a few more template variables initialized, such as the mapping of which threadblock tile of the output matrix
is computed by which threadblock launched on an SM, and the CUDA SM architecture of the GPU you want to run on.
|
||||
These are all put together to create a template variable which describes CUTLASS Implicit GEMM
|
||||
kernel using cutlass::conv::device::ImplicitGemm template.
|
||||
|
||||
The next step is to initialize physical data, instantiate and initialize the CUTLASS kernel and run it.
We use CUTLASS utilities to initialize, fill, and compare tensors as they are simple and don't get
in the way of learning CUTLASS.
|
||||
Once all the tensors are initialized and filled with data, create arguments tuple to launch CUTLASS
|
||||
kernel which takes problem size (N = 1, H = 64, W = 64, C = 128), filter size (K = 64,
|
||||
R = 3, S = 3, C = 128 ), padding, strides, dilation, tensors, alpha, beta and the
|
||||
important one, the split k-dimension factor. Along with that, we query CUTLASS if any scratch-space
memory is required by the kernel we instantiated. If yes, we create it and pass it along with other
arguments created to initialize the CUTLASS kernel; then the kernel is launched.
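In code, the query-and-initialize sequence condenses to the following sketch (the full version appears
further down in this file):

ImplicitGemm implicit_gemm_op;
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
CUTLASS_CHECK(implicit_gemm_op.can_implement(arguments));
CUTLASS_CHECK(implicit_gemm_op.initialize(arguments, workspace.get()));
CUTLASS_CHECK(implicit_gemm_op());   // launch the implicit GEMM convolution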
|
||||
In this example, we later launch a reference convolution kernel (from CUTLASS utilities) to
check whether the output from the CUTLASS kernel matches that of the reference implicit GEMM kernel.
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/convolution.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
// The code section below describes datatype for input, output tensors and computation between
|
||||
// elements
|
||||
using ElementAccumulator = int32_t; // Data type of accumulator
|
||||
using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta)
|
||||
using ElementInputA = cutlass::int4b_t; // Data type of elements in input tensor
|
||||
using ElementInputB = cutlass::int4b_t; // Data type of elements in input tensor
|
||||
using ElementOutput = cutlass::int4b_t; // Data type of elements in output tensor
|
||||
|
||||
using LayoutInputA = cutlass::layout::TensorNHWC;
|
||||
using LayoutInputB = cutlass::layout::TensorNHWC;
|
||||
using LayoutOutput = cutlass::layout::TensorNHWC;
|
||||
|
||||
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
|
||||
using MMAOp = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
// This code section describes CUDA SM architecture number
|
||||
using SmArch = cutlass::arch::Sm75;
|
||||
|
||||
// This code section describes the tile size a thread block will compute
|
||||
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; // Threadblock tile shape
|
||||
|
||||
// This code section describes tile size a warp will compute
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; // Warp tile shape
|
||||
|
||||
// This code section describes the size of MMA op
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; // TensorCore instruction shape
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
// Number of pipelines you want to use
|
||||
constexpr int NumStages = 2;
|
||||
|
||||
// This code section describes the epilogue part of the kernel, we use default value
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationClamp<
|
||||
ElementOutput, // Data type of output matrix.
|
||||
8, // The number of elements per vectorized.
|
||||
// memory access. This becomes the vector width of
|
||||
// math instructions in the epilogue too.
|
||||
ElementAccumulator, // Data type of accumulator
|
||||
ElementComputeEpilogue>; // Data type for alpha/beta in linear combination
|
||||
|
||||
|
||||
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementInputA, LayoutInputA,
|
||||
ElementInputB, LayoutInputB,
|
||||
ElementOutput, LayoutOutput,
|
||||
ElementAccumulator,
|
||||
MMAOp,
|
||||
SmArch,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kAnalytic
|
||||
>::Kernel;
|
||||
|
||||
using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
cutlass::Tensor4DCoord input_size;
|
||||
cutlass::Tensor4DCoord filter_size;
|
||||
cutlass::Tensor4DCoord padding;
|
||||
cutlass::MatrixCoord conv_stride;
|
||||
cutlass::MatrixCoord dilation;
|
||||
bool reference_check;
|
||||
bool measure_performance;
|
||||
int iterations;
|
||||
bool save_workspace;
|
||||
ElementComputeEpilogue alpha;
|
||||
ElementComputeEpilogue beta;
|
||||
bool benchmark;
|
||||
std::string tag;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
input_size(1, 32, 32, 32),
|
||||
filter_size(32, 3, 3, 32),
|
||||
padding(1, 1, 1, 1),
|
||||
conv_stride(1, 1),
|
||||
dilation(1, 1),
|
||||
reference_check(false),
|
||||
measure_performance(true),
|
||||
iterations(20),
|
||||
save_workspace(false),
|
||||
alpha(1),
|
||||
beta(0),
|
||||
benchmark(false) { }
|
||||
|
||||
// Verify the problem size is compatible with the CUTLASS Convolution implementation.
|
||||
bool valid() {
|
||||
|
||||
//
|
||||
// CUTLASS attempts to load 128b vectors of int4b_t elements. Consequently,
|
||||
// all pointers, strides, and tensor extents must be divisible by 32 elements.
|
||||
//
|
||||
int const kAlignment = 32;
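// The alignment of 32 elements follows from the vector width: one 128-bit access holds
// 128 / 4 = 32 int4b_t elements, so both C (input channels) and K (filter count) must be
// multiples of 32 for fully vectorized loads.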
if ((input_size.c() % kAlignment) ||
|
||||
(filter_size.n() % kAlignment)) {
|
||||
|
||||
// misaligned tensors
|
||||
return false;
|
||||
}
|
||||
|
||||
// Invalid padding
|
||||
if ((padding.h() != filter_size.h() / 2) ||
|
||||
(padding.w() != filter_size.w() / 2)) {
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Updates input and filter sizes
|
||||
void update(
|
||||
cutlass::Tensor4DCoord input_size,
|
||||
cutlass::Tensor4DCoord filter_size) {
|
||||
|
||||
this->input_size = input_size;
|
||||
this->filter_size = filter_size;
|
||||
|
||||
padding.n() = filter_size.h() / 2;
|
||||
padding.h() = filter_size.h() / 2;
|
||||
padding.w() = filter_size.w() / 2;
|
||||
padding.c() = filter_size.w() / 2;
|
||||
}
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("ref-check")) {
|
||||
reference_check = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("perf-check")) {
|
||||
measure_performance = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("save-workspace")) {
|
||||
save_workspace = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("benchmark")) {
|
||||
benchmark = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("n", input_size.n());
|
||||
cmd.get_cmd_line_argument("h", input_size.h());
|
||||
cmd.get_cmd_line_argument("w", input_size.w());
|
||||
cmd.get_cmd_line_argument("c", input_size.c());
|
||||
|
||||
cmd.get_cmd_line_argument("k", filter_size.n());
|
||||
cmd.get_cmd_line_argument("r", filter_size.h());
|
||||
cmd.get_cmd_line_argument("s", filter_size.w());
|
||||
filter_size.c() = input_size.c();
|
||||
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("tag", tag);
|
||||
|
||||
if (filter_size.h() == 3 && filter_size.w() == 3) {
|
||||
padding = {1, 1, 1, 1};
|
||||
}
|
||||
else {
|
||||
filter_size.h() = 1;
|
||||
filter_size.w() = 1;
|
||||
padding = {0, 0, 0, 0};
|
||||
}
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "09_turing_tensorop_conv2dfprop example\n\n"
|
||||
<< " This example uses Turing's Tensor Core operators on int4 data types to compute\n"
|
||||
<< " forward convolution on tensors of layout NHWC.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --n=<int> Input tensor extent N\n"
|
||||
<< " --h=<int> Input tensor extent H\n"
|
||||
<< " --w=<int> Input tensor extent W\n"
|
||||
<< " --c=<int> Input tensor extent C\n"
|
||||
<< " --k=<int> Filter extent K\n"
|
||||
<< " --r=<int> Filter extent R\n"
|
||||
<< " --s=<int> Filter extent S\n\n"
|
||||
<< " --alpha=<float> Epilogue scalar alpha\n"
|
||||
<< " --beta=<float> Epilogue scalar beta\n\n"
|
||||
<< " --ref-check If set (true), reference check on the host is computed\n"
|
||||
<< " --perf-check If set (true), performance is measured.\n"
|
||||
<< " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n"
|
||||
<< " --save-workspace If set, workspace is written to a text file.\n"
|
||||
<< " --tag=<string> String to replicate across the first column in the results table\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ ./examples/09_turing_tensorop_conv2dfprop/09_turing_tensorop_conv2dfprop --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n"
|
||||
<< "$ ./examples/09_turing_tensorop_conv2dfprop/09_turing_tensorop_conv2dfprop --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Computes the output tensor size (NPQK)
|
||||
cutlass::Tensor4DCoord output_size() const {
|
||||
return cutlass::Tensor4DCoord(
|
||||
input_size.n(),
|
||||
(input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1,
|
||||
(input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1,
|
||||
filter_size.n());
|
||||
}
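// For example, with the default Options (input 1x32x32x32, filter 32x3x3x32, padding 1, stride 1),
// output_size() evaluates H and W as (32 + 1 + 1 - 3) / 1 + 1 = 32, giving a 1x32x32x32 output tensor.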
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const {
|
||||
|
||||
// Number of multiply-adds = NPQK * CRS
|
||||
int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c());
|
||||
|
||||
// Two flops per multiply-add
|
||||
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Result {
|
||||
double runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cutlass::Status reference_check;
|
||||
cudaError_t error;
|
||||
|
||||
Result():
|
||||
runtime_ms(0),
|
||||
gflops(0),
|
||||
status(cutlass::Status::kSuccess),
|
||||
reference_check(cutlass::Status::kInvalid),
|
||||
error(cudaSuccess) { }
|
||||
|
||||
static std::ostream & print_header(std::ostream &out, Options const &options) {
|
||||
|
||||
if (!options.tag.empty()) {
|
||||
out << "Name,";
|
||||
}
|
||||
|
||||
out << "Layer,N,H,W,C,K,R,S,Runtime,GFLOPs";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
std::ostream & print(std::ostream &out, int idx, Options const &options) {
|
||||
|
||||
if (!options.tag.empty()) {
|
||||
out << options.tag << ",";
|
||||
}
|
||||
|
||||
out
|
||||
<< "conv_" << idx << ","
|
||||
<< options.input_size.n() << ","
|
||||
<< options.input_size.h() << ","
|
||||
<< options.input_size.w() << ","
|
||||
<< options.input_size.c() << ","
|
||||
<< options.filter_size.n() << ","
|
||||
<< options.filter_size.h() << ","
|
||||
<< options.filter_size.w() << ","
|
||||
<< runtime_ms << ","
|
||||
<< gflops;
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Runs one benchmark
|
||||
Result profile_convolution(Options const &options) {
|
||||
|
||||
Result result;
|
||||
|
||||
//
|
||||
// Allocate host-device tensors using the CUTLASS Utilities.
|
||||
//
|
||||
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.input_size);
|
||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.filter_size);
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.output_size());
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_c(options.output_size());
|
||||
|
||||
//
|
||||
// Initialize tensors
|
||||
//
|
||||
|
||||
// Fill tensor A on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a.host_view(),
|
||||
1,
|
||||
ElementInputA(7),
|
||||
ElementInputA(-8),
|
||||
0);
|
||||
|
||||
// Fill tensor B on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b.host_view(),
|
||||
1,
|
||||
ElementInputB(7),
|
||||
ElementInputB(-8),
|
||||
0);
|
||||
|
||||
// Fill tensor C on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_c.host_view());
|
||||
|
||||
// Fill tensor C for reference on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_ref_c.host_view());
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a.sync_device();
|
||||
tensor_b.sync_device();
|
||||
tensor_c.sync_device();
|
||||
tensor_ref_c.sync_device();
|
||||
|
||||
//
|
||||
// Define arguments for CUTLASS Convolution
|
||||
//
|
||||
|
||||
// mode (kCrossCorrelation or kConvolution)
|
||||
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
|
||||
|
||||
// Split K dimension into 1 partitions
|
||||
int split_k_slices = 1;
|
||||
|
||||
// Construct Conv2dProblemSize with user defined output size
|
||||
cutlass::conv::Conv2dProblemSize problem_size(
|
||||
options.input_size,
|
||||
options.filter_size,
|
||||
options.padding,
|
||||
options.conv_stride,
|
||||
options.dilation,
|
||||
options.output_size(),
|
||||
mode,
|
||||
split_k_slices);
|
||||
|
||||
// Construct ImplicitGemm::Argument structure with conv2d
|
||||
// problem size, data pointers, and epilogue values
|
||||
typename ImplicitGemm::Arguments arguments{
|
||||
problem_size,
|
||||
tensor_a.device_ref(),
|
||||
tensor_b.device_ref(),
|
||||
tensor_c.device_ref(),
|
||||
tensor_c.device_ref(),
|
||||
{options.alpha, options.beta},
|
||||
};
|
||||
|
||||
//
|
||||
// Initialize CUTLASS Convolution
|
||||
//
|
||||
|
||||
ImplicitGemm implicit_gemm_op;
|
||||
|
||||
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
result.status = implicit_gemm_op.can_implement(arguments);
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
result.status = implicit_gemm_op.initialize(arguments, workspace.get());
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
//
|
||||
// Launch initialized CUTLASS kernel
|
||||
//
|
||||
result.status = implicit_gemm_op();
|
||||
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
//
|
||||
// Optional reference check
|
||||
//
|
||||
|
||||
if (options.reference_check) {
|
||||
std::cout << "Verification on host...\n";
|
||||
|
||||
// Compute with reference implementation
|
||||
cutlass::reference::host::Conv2dFprop<
|
||||
ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementComputeEpilogue,
|
||||
ElementAccumulator,
|
||||
cutlass::NumericConverterClamp<ElementOutput, ElementComputeEpilogue>
|
||||
>(
|
||||
problem_size,
|
||||
tensor_a.host_ref(),
|
||||
tensor_b.host_ref(),
|
||||
tensor_c.host_ref(),
|
||||
tensor_ref_c.host_ref(),
|
||||
options.alpha,
|
||||
options.beta
|
||||
);
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
tensor_c.sync_host();
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
tensor_c.host_view(),
|
||||
tensor_ref_c.host_view());
|
||||
|
||||
if (!passed) {
|
||||
result.reference_check = cutlass::Status::kErrorInternal;
|
||||
std::cout << "ERROR - results miscompared.\n";
|
||||
}
|
||||
else {
|
||||
result.reference_check = cutlass::Status::kSuccess;
|
||||
std::cout << "Passed.\n";
|
||||
}
|
||||
}
|
||||
else {
|
||||
result.reference_check = cutlass::Status::kInvalid;
|
||||
}
|
||||
|
||||
if (options.save_workspace) {
|
||||
|
||||
std::stringstream ss;
|
||||
|
||||
ss << "09_tensor_conv_workspace_conv2dfprop_"
|
||||
<< options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
|
||||
<< "_"
|
||||
<< options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
|
||||
<< ".dat";
|
||||
|
||||
std::ofstream output_workspace(ss.str());
|
||||
|
||||
output_workspace
|
||||
<< "Input = \n" << tensor_a.host_view() << "\n\n"
|
||||
<< "Filters = \n" << tensor_b.host_view() << "\n\n";
|
||||
|
||||
if (options.reference_check) {
|
||||
output_workspace << "Reference = \n" << tensor_ref_c.host_view() << "\n\n";
|
||||
}
|
||||
|
||||
output_workspace << "Computed = \n" << tensor_c.host_view() << std::endl;
|
||||
|
||||
std::cout << "Results written to '" << ss.str() << "'." << std::endl;
|
||||
}
|
||||
|
||||
//
|
||||
// Performance measurement
|
||||
//
|
||||
|
||||
if (options.measure_performance) {
|
||||
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (auto & event : events) {
|
||||
result.error = cudaEventCreate(&event);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// Record an event at the start of a series of convolution operations.
|
||||
result.error = cudaEventRecord(events[0]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Launch a sequence of implicit GEMM operations on the device
|
||||
for (int iteration = 0; iteration < options.iterations; ++iteration) {
|
||||
result.status = implicit_gemm_op();
|
||||
CUTLASS_CHECK(result.status);
|
||||
}
|
||||
|
||||
// Record an event when the convolutions have been launched.
|
||||
result.error = cudaEventRecord(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Wait for work on the device to complete.
|
||||
result.error = cudaEventSynchronize(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Print average runtime and GFLOPs.
|
||||
result.runtime_ms = double(runtime_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.runtime_ms / 1000.0);
|
||||
|
||||
// Cleanup
|
||||
for (auto event : events) {
|
||||
(void)cudaEventDestroy(event);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////

int main(int argc, char const **args) {

  // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
  //
  // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
  if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
    std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
    return 0;
  }

  cudaDeviceProp props;
  CUDA_CHECK(cudaGetDeviceProperties(&props, 0));

  if (!(props.major > 7 || (props.major == 7 && props.minor >= 5))) {
    std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75."
              << std::endl;
    return 0;
  }

  Options options;

  options.parse(argc, args);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  if (options.benchmark) {
    // Benchmark several layers

    int batch_sizes[] = {1, 32, 64, 128, 256, 512};

    struct Benchmark {
      int h, w, c, k, r, s;
    } layers[] = {
      {56, 56,   64,  256, 1, 1},
      {56, 56,   64,   64, 1, 1},
      {56, 56,   64,   64, 3, 3},
      {56, 56,  256,   64, 1, 1},
      {56, 56,  256,  512, 1, 1},
      {56, 56,  256,  128, 1, 1},
      {28, 28,  128,  128, 3, 3},
      {28, 28,  128,  512, 1, 1},
      {28, 28,  512,  128, 1, 1},
      {28, 28,  512, 1024, 1, 1},
      {28, 28,  512,  256, 1, 1},
      {14, 14,  256,  256, 3, 3},
      {14, 14,  256, 1024, 1, 1},
      {14, 14, 1024,  256, 1, 1},
      {14, 14, 1024, 2048, 1, 1},
      {14, 14, 1024,  512, 1, 1},
      { 7,  7,  512,  512, 3, 3},
    };

    Result::print_header(std::cout, options) << std::endl;

    int idx = 1;

    for (auto const &layer : layers) {
      for (auto N : batch_sizes) {

        options.update({N, layer.h, layer.w, layer.c}, {layer.k, layer.r, layer.s, layer.c});

        Result result = profile_convolution(options);
        result.print(std::cout, idx, options) << std::endl;
      }

      ++idx;
    }
  }
  else {

    // Execute one problem size
    if (!options.valid()) {
      std::cerr << "Invalid problem." << std::endl;
      return -1;
    }

    Result result = profile_convolution(options);

    Result::print_header(std::cout, options) << std::endl;
    result.print(std::cout, 1, options) << std::endl;
  }

  return 0;
}

/////////////////////////////////////////////////////////////////////////////////////////////////
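Each benchmark entry above supplies only the activation extent (h, w, c) and filter extent (k, r, s); the output extent that fixes the implicit GEMM shape follows from standard convolution arithmetic, which Options::update() presumably applies when it builds a cutlass::conv::Conv2dProblemSize. A minimal sketch of that arithmetic, assuming dilation 1 (the helper name is illustrative):

static int conv_output_extent(int input, int pad, int filter, int stride) {
  // e.g. P = (H + 2*pad_h - R) / stride_h + 1, and likewise Q from W and S
  return (input + 2 * pad - filter) / stride + 1;
}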
@ -1,26 +1,34 @@
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
|
||||
# Planar Complex GEMM example
|
||||
cutlass_example_add_executable(
|
||||
10_planar_complex
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -50,7 +56,7 @@
|
||||
To build strictly the planar complex kernels needed for general application, execute the following
|
||||
CMake command in an empty build directory.
|
||||
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" \
|
||||
-DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_*gemm_planar_complex
|
||||
|
||||
This builds all planar complex GEMM variants for Volta and Turing architectures.
|
||||
@ -59,7 +65,7 @@
|
||||
specified as follows. This only builds planar complex GEMMs targeting Tensor Cores for
|
||||
the 'CN' layout configuration (conjugate A operand with both A and B as column-major).
|
||||
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" \
|
||||
-DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s*gemm_planar_complex_f16*cn
|
||||
|
||||
$ make 10_planar_complex
|
||||
@ -167,15 +173,15 @@ struct Options {
|
||||
<< " This example uses the CUTLASS Library to execute Planar Complex GEMM computations.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --m <int> GEMM M dimension\n"
|
||||
<< " --n <int> GEMM N dimension\n"
|
||||
<< " --k <int> GEMM K dimension\n"
|
||||
<< " --batch <int> Number of GEMM operations executed in one batch\n"
|
||||
<< " --alpha <f32> Epilogue scalar alpha (real part)\n"
|
||||
<< " --alpha_i <f32> Epilogue scalar alpha (imaginary part)\n"
|
||||
<< " --beta <f32> Epilogue scalar beta (real part)\n\n"
|
||||
<< " --beta_i <f32> Epilogue scalar beta (imaginary part)\n\n"
|
||||
<< " --iterations <int> Number of profiling iterations to perform.\n\n";
|
||||
<< " --m=<int> GEMM M dimension\n"
|
||||
<< " --n=<int> GEMM N dimension\n"
|
||||
<< " --k=<int> GEMM K dimension\n"
|
||||
<< " --batch=<int> Number of GEMM operations executed in one batch\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha (real part)\n"
|
||||
<< " --alpha_i=<f32> Epilogue scalar alpha (imaginary part)\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta (real part)\n\n"
|
||||
<< " --beta_i=<f32> Epilogue scalar beta (imaginary part)\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ ./examples/10_planar_complex/10_planar_complex --batch=7 --m=1024 --n=512 --k=1024 \\\n"
|
||||
@ -275,10 +281,10 @@ public:
|
||||
int64_t batch_stride_C = int64_t(problem_size.m()) * problem_size.n() * 2;
|
||||
int64_t batch_stride_D = int64_t(problem_size.m()) * problem_size.n() * 2;
|
||||
|
||||
int lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0);
|
||||
int ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0);
|
||||
int ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
|
||||
int ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
|
||||
typename LayoutA::Stride::Index lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0);
|
||||
typename LayoutB::Stride::Index ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0);
|
||||
typename LayoutC::Stride::Index ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
|
||||
typename LayoutC::Stride::Index ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
|
||||
|
||||
int64_t imag_stride_A = int64_t(problem_size.m()) * problem_size.k();
|
||||
int64_t imag_stride_B = int64_t(problem_size.k()) * problem_size.n();
|
||||
@ -526,6 +532,11 @@ int main(int argc, char const **args) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
else {
|
||||
// NVIDIA Ampere Architecture GPUs (SM80 and later) are fully supported on CUDA 11 Toolkit and beyond.
|
||||
//
|
||||
// fall through
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
|
||||
@ -1,26 +1,34 @@
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
|
||||
# Planar Complex Array GEMM example
|
||||
cutlass_example_add_executable(
|
||||
11_planar_complex_array
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -48,7 +54,7 @@
|
||||
To build strictly the planar complex kernels needed for general application, execute the following
|
||||
CMake command in an empty build directory.
|
||||
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" \
|
||||
-DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_*gemm_planar_complex
|
||||
|
||||
This builds all planar complex GEMM variants for Volta and Turing architectures.
|
||||
@ -57,7 +63,7 @@
|
||||
specified as follows. This only builds planar complex GEMMs targeting Tensor Cores for
|
||||
the 'CN' layout configuration (conjugate A operand with both A and B as column-major).
|
||||
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" \
|
||||
-DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s*gemm_planar_complex_array_f16*cn
|
||||
|
||||
$ make 11_planar_complex_array
|
||||
@ -165,15 +171,15 @@ struct Options {
|
||||
<< " This example uses the CUTLASS Library to execute Planar Complex Array GEMM computations.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --m <int> GEMM M dimension\n"
|
||||
<< " --n <int> GEMM N dimension\n"
|
||||
<< " --k <int> GEMM K dimension\n"
|
||||
<< " --batch <int> Number of GEMM operations executed in one batch\n"
|
||||
<< " --alpha <f32> Epilogue scalar alpha (real part)\n"
|
||||
<< " --alpha_i <f32> Epilogue scalar alpha (imaginary part)\n"
|
||||
<< " --beta <f32> Epilogue scalar beta (real part)\n\n"
|
||||
<< " --beta_i <f32> Epilogue scalar beta (imaginary part)\n\n"
|
||||
<< " --iterations <int> Number of profiling iterations to perform.\n";
|
||||
<< " --m=<int> GEMM M dimension\n"
|
||||
<< " --n=<int> GEMM N dimension\n"
|
||||
<< " --k=<int> GEMM K dimension\n"
|
||||
<< " --batch=<int> Number of GEMM operations executed in one batch\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha (real part)\n"
|
||||
<< " --alpha_i=<f32> Epilogue scalar alpha (imaginary part)\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta (real part)\n\n"
|
||||
<< " --beta_i=<f32> Epilogue scalar beta (imaginary part)\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ ./examples/11_planar_complex_array/11_planar_complex_array\n\n";
|
||||
@ -292,10 +298,11 @@ public:
|
||||
int64_t batch_stride_C = int64_t(problem_size.m()) * problem_size.n() * 2;
|
||||
int64_t batch_stride_D = int64_t(problem_size.m()) * problem_size.n() * 2;
|
||||
|
||||
int lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0);
|
||||
int ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0);
|
||||
int ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
|
||||
int ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
|
||||
typename LayoutA::Stride::Index lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0);
|
||||
typename LayoutB::Stride::Index ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0);
|
||||
typename LayoutC::Stride::Index ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
|
||||
typename LayoutC::Stride::Index ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
|
||||
|
||||
|
||||
int64_t imag_stride_A = int64_t(problem_size.m()) * problem_size.k();
|
||||
int64_t imag_stride_B = int64_t(problem_size.k()) * problem_size.n();
|
||||
@ -586,6 +593,11 @@ int main(int argc, char const **args) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
else {
|
||||
// NVIDIA Ampere Architecture GPUs (SM80 and later) are fully supported on CUDA 11 Toolkit and beyond.
|
||||
//
|
||||
// fall through
|
||||
}
|
||||
|
||||
//
|
||||
// Parse options
|
||||
|
||||
@ -1,25 +1,33 @@
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
12_gemm_bias_relu
|
||||
gemm_bias_relu.cu
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -48,11 +54,21 @@ using ElementInputA = cutlass::half_t; // <- data type of elements
|
||||
using ElementInputB = cutlass::half_t; // <- data type of elements in input matrix B
|
||||
using ElementOutput = float; // <- data type of elements in output matrix D
|
||||
|
||||
// The code section below describes matrix layout of input and output matrices. Column Major for
|
||||
// Matrix A, Row Major for Matrix B and Row Major for Matrix C
|
||||
// The code section below describes matrix layout of input and output matrices.
|
||||
// Column Major for Matrix A, B and C.
|
||||
|
||||
// Note that if the output is column major, the bias has to be per row, i.e. every row has a different bias.
// If the output is row major, the bias has to be per column, i.e. every column has a different bias.
// A few additional notes:
// 1) only a row major epilogue is available.
// 2) if the output is column major, A and B are swapped so that the row major epilogue can still be used.
// 3) the Mx1 bias vector becomes 1xM after the swapping/transposing.
// 4) the existing OutputIterator can be used to load the 1xM bias vector.
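// In practice the bias is supplied as the C operand with its stride set to zero (see the GEMM
// arguments later in this example, which project away the N dimension by setting the stride to
// zero), so the same Mx1 vector is reused for every column index n instead of reading a distinct
// column from a full M x N matrix C.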
|
||||
using LayoutInputA = cutlass::layout::ColumnMajor;
|
||||
using LayoutInputB = cutlass::layout::ColumnMajor;
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
using LayoutOutput = cutlass::layout::ColumnMajor;
|
||||
|
||||
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
|
||||
using MMAOp = cutlass::arch::OpClassTensorOp;
|
||||
@ -73,17 +89,18 @@ using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSw
|
||||
|
||||
// Define the epilogue operation as LinearCombinationRelu. This is approximately equal to
|
||||
//
|
||||
// d_ij = max(0, alpha * sum_k(a_ik * b_kj) + beta * c_ij )
|
||||
// d_ij = max(0, alpha * sum_k(a_ik * b_kj) + c_ij )
|
||||
//
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput, // <- data type of output matrix
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value, // <- this is the number of elements per
|
||||
// vectorized memory access. For half
|
||||
// precision, it's 8 elements. This becomes
|
||||
// the vector width of math instructions in
|
||||
// epilogue too
|
||||
ElementAccumulator, // <- data type of accumulator
|
||||
ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function
|
||||
ElementOutput, // <- data type of output matrix
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value, // <- this is the number of elements per
|
||||
// vectorized memory access. For half
|
||||
// precision, it's 8 elements. This becomes
|
||||
// the vector width of math instructions in
|
||||
// epilogue too
|
||||
ElementAccumulator, // <- data type of accumulator
|
||||
ElementComputeEpilogue, // <- data type for alpha in linear combination function
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling>; // <- alpha x C + bias
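// ScaleType::NoBetaScaling removes the beta term from the epilogue entirely, so it evaluates
// D = relu(alpha * accumulator + bias); this is also why the kernel arguments further below
// pass {alpha} instead of {alpha, beta}.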
|
||||
|
||||
// Number of pipelines you want to use
|
||||
constexpr int NumStages = 2;
|
||||
@ -106,21 +123,6 @@ using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
|
||||
|
||||
int run() {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!(props.major * 10 + props.minor >= 75)) {
|
||||
std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75."
|
||||
<< std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
const int length_m = 5120;
|
||||
const int length_n = 4096;
|
||||
const int length_k = 4096;
|
||||
@ -132,7 +134,7 @@ int run() {
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
|
||||
problem_size.mk()); // <- Create matrix A with dimensions M x K
|
||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
|
||||
problem_size.nk()); // <- Create matrix B with dimensions N x K
|
||||
problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c_bias(
|
||||
{problem_size.m(), 1}); // <- Create matrix C with dimensions M x 1
|
||||
@ -175,9 +177,8 @@ int run() {
|
||||
tensor_d.sync_device();
|
||||
tensor_ref_d.sync_device();
|
||||
|
||||
// Initialize alpha and beta for dot product computation
|
||||
// Initialize alpha for dot product computation
|
||||
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
|
||||
ElementComputeEpilogue beta = ElementComputeEpilogue(0);
|
||||
|
||||
// Split K dimension into 1 partitions
|
||||
int split_k_slices = 1;
|
||||
@ -193,7 +194,7 @@ int run() {
|
||||
// to project away the N dimension by setting the stride to zero.
|
||||
|
||||
tensor_d.device_ref(), // <- reference to matrix D on device
|
||||
{alpha, beta}, // <- tuple of alpha and beta
|
||||
{alpha}, // <- alpha
|
||||
split_k_slices}; // <- k-dimension split factor
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
@ -205,8 +206,12 @@ int run() {
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm_op;
|
||||
|
||||
// Check the problem size is supported or not
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
|
||||
status = gemm_op.initialize(arguments, workspace.get());
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// Launch initialized CUTLASS kernel
|
||||
@ -234,7 +239,6 @@ int run() {
|
||||
tensor_a.device_ref(),
|
||||
tensor_b.device_ref(),
|
||||
0,
|
||||
tensor_c_bias.device_ref(),
|
||||
tensor_ref_d.device_ref());
|
||||
|
||||
// Wait for kernels to finish
|
||||
@ -249,7 +253,7 @@ int run() {
|
||||
for (int j = 0; j < problem_size.n(); ++j) {
|
||||
tensor_ref_d.at({i, j}) = std::max(
|
||||
ElementOutput(0),
|
||||
ElementOutput(tensor_ref_d.at({i, j}) + beta * tensor_c_bias.at({i, 0}))
|
||||
ElementOutput(tensor_ref_d.at({i, j}) + tensor_c_bias.at({i, 0}))
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -266,17 +270,35 @@ int run() {
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
bool notSupported = false;
|
||||
|
||||
// Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
|
||||
std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!(props.major * 10 + props.minor >= 75)) {
|
||||
std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75."
|
||||
<< std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
else {
|
||||
return run();
|
||||
}
|
||||
|
||||
return run();
|
||||
}
|
||||
|
||||
|
||||
@ -1,33 +0,0 @@
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cutlass_example_add_executable(
|
||||
13_fused_two_gemms
|
||||
fused_gemm.cu
|
||||
)
|
||||
|
||||
target_include_directories(
|
||||
13_fused_two_gemms
|
||||
PRIVATE
|
||||
.
|
||||
)
|
||||
|
||||
@ -1,74 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/**
|
||||
*/
|
||||
|
||||
#include "b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm75.h"
|
||||
#include "b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm75.h"
|
||||
|
||||
int run() {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!(props.major * 10 + props.minor >= 75)) {
|
||||
std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75."
|
||||
<< std::endl;
|
||||
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
|
||||
run_nonfused_gemm_f16();
|
||||
run_fused_gemm_f16();
|
||||
run_nonfused_gemm_s8();
|
||||
run_fused_gemm_s8();
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main() {
|
||||
// Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
|
||||
std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
|
||||
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
else {
|
||||
return run();
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,289 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "threadblock/b2b_mma_pipelined.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA_,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA_,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB_,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB_,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator_,
|
||||
/// Layout type for C and D matrix operands
|
||||
typename LayoutC_,
|
||||
/// Operator class tag
|
||||
typename OperatorClass_,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag_,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0_,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1_,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0_,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1_,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape_,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
typename Operator,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
bool AccumulatorsInRowMajor = false>
|
||||
struct DefaultB2bMma;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp>
|
||||
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
kAlignmentB, ElementAccumulator, layout::RowMajor,
|
||||
OperatorClass, ArchTag,
|
||||
ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1,
|
||||
InstructionShape, 2, Operator, EpilogueOutputOp, false> {
|
||||
// Define the MmaCore components
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator, layout::RowMajor,
|
||||
OperatorClass, 2, Operator>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator, layout::RowMajor,
|
||||
OperatorClass, 2, Operator>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA0 =
|
||||
cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore0::Shape::kM, MmaCore0::Shape::kK>,
|
||||
ElementA, LayoutA, 1, typename MmaCore0::IteratorThreadMapA, kAlignmentA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB0 =
|
||||
cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore0::Shape::kK, MmaCore0::Shape::kN>,
|
||||
ElementB, LayoutB, 0, typename MmaCore0::IteratorThreadMapB, kAlignmentB>;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
using AccumulatorLayout = cutlass::layout::ColumnMajor;
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp, true>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB1 =
|
||||
cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore1::Shape::kK, MmaCore1::Shape::kN>,
|
||||
ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined<
|
||||
typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
|
||||
IteratorB0, typename MmaCore0::SmemIteratorB,
|
||||
typename MmaCore1::Shape, FragmentIteratorA1,
|
||||
IteratorB1, typename MmaCore1::SmemIteratorB,
|
||||
ElementAccumulator, layout::RowMajor,
|
||||
EpilogueOutputOp,
|
||||
typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>;
|
||||
|
||||
};
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for column-major-interleaved output
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
    typename OperatorClass,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape0,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape1,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape0,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape1,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Number of Interleaved K
    int InterleavedK>
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
                     kAlignmentB, ElementAccumulator,
                     layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, ArchTag,
                     ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
                     InstructionShape, 2, Operator, EpilogueOutputOp, true> {

  // Define the MmaCore components
  using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
      ElementB, LayoutB, ElementAccumulator,
      layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, 2, Operator,
      true>;
  using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
      ElementB, LayoutB, ElementAccumulator,
      layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, 2, Operator,
      true>;

  static_assert(kAlignmentA == 128 / sizeof_bits<ElementA>::value,
                "Alignment must match thread data map's vector length");

  static_assert(kAlignmentB == 128 / sizeof_bits<ElementB>::value,
                "Alignment must match thread data map's vector length");

  // Define iterators over tiles from the A operand
  using IteratorA0 = cutlass::transform::threadblock::PredicatedTileIterator<
      cutlass::MatrixShape<MmaCore0::Shape::kM, MmaCore0::Shape::kK>, ElementA,
      LayoutA, 1, typename MmaCore0::IteratorThreadMapA>;

  // Define iterators over tiles from the B operand
  using IteratorB0 = cutlass::transform::threadblock::PredicatedTileIterator<
      cutlass::MatrixShape<MmaCore0::Shape::kK, MmaCore0::Shape::kN>, ElementB,
      LayoutB, 0, typename MmaCore0::IteratorThreadMapB>;

  // Use fragment iterator for the A operand
  using AccumulatorLayout = cutlass::layout::RowMajor;  // AccumulatorsInRowMajor = true
  using FragmentIteratorA1 =
      cutlass::gemm::warp::MmaTensorOpFragmentIterator<
          cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>,  // warp shape
          cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>,         // accumulator shape
          MmaCore1::Shape::kK,  // kBlocksColumn
          ElementAccumulator, ElementA, AccumulatorLayout,
          InstructionShape, EpilogueOutputOp, true /*only handle beta=0 for 1st GEMM epilogue*/>;

  // Define iterators over tiles from the B operand
  using IteratorB1 =
      cutlass::transform::threadblock::PredicatedTileIterator<
          cutlass::MatrixShape<MmaCore1::Shape::kK, MmaCore1::Shape::kN>,
          ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>;

  // Define the threadblock-scoped pipelined matrix multiply
  using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined<
      typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
      IteratorB0, typename MmaCore0::SmemIteratorB,
      typename MmaCore1::Shape, FragmentIteratorA1,
      IteratorB1, typename MmaCore1::SmemIteratorB,
      ElementAccumulator, layout::ColumnMajorInterleaved<InterleavedK>,
      EpilogueOutputOp,
      typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>;
};

////////////////////////////////////////////////////////////////////////////////

}  // namespace threadblock
}  // namespace gemm
}  // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
82
examples/13_two_tensor_op_fusion/CMakeLists.txt
Normal file
@ -0,0 +1,82 @@

# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

include_directories(
  .
)

add_custom_target(13_fused_two_gemms)

add_custom_target(13_fused_two_convs)

add_custom_target(13_two_tensor_op_fusion
  DEPENDS 13_fused_two_gemms
          13_fused_two_convs
)

foreach(FUSION_CONV_EXAMPLE
  fused_two_convs_f16_sm75_rf
  fused_two_convs_f16_sm75_shmem
  fused_two_convs_f16_sm80_rf
  fused_two_convs_f16_sm80_shmem
  fused_two_convs_s8_sm75_rf
  fused_two_convs_s8_sm75_shmem
  fused_two_convs_s8_sm80_rf
  fused_two_convs_s8_sm80_shmem
)

  cutlass_example_add_executable(
    13_${FUSION_CONV_EXAMPLE}
    ${FUSION_CONV_EXAMPLE}.cu
  )

  add_dependencies(13_fused_two_convs 13_${FUSION_CONV_EXAMPLE})

endforeach()

foreach(FUSION_GEMM_EXAMPLE
  fused_two_gemms_f16_sm75_rf
  fused_two_gemms_f16_sm75_shmem
  fused_two_gemms_f16_sm80_rf
  fused_two_gemms_f16_sm80_shmem
  fused_two_gemms_s8_sm75_rf
  fused_two_gemms_s8_sm75_shmem
  fused_two_gemms_s8_sm80_rf
  fused_two_gemms_s8_sm80_shmem
)
  cutlass_example_add_executable(
    13_${FUSION_GEMM_EXAMPLE}
    ${FUSION_GEMM_EXAMPLE}.cu
  )

  add_dependencies(13_fused_two_gemms 13_${FUSION_GEMM_EXAMPLE})

endforeach()
95
examples/13_two_tensor_op_fusion/README.md
Normal file
@ -0,0 +1,95 @@

# Introduction

This example shows how to fuse two back-to-back GEMMs/Convolutions into one kernel.

<p align="center"><img src=/media/images/13_example_fusion.png></p>

When running two unfused GEMM/Conv operations, each operation loads one input
activation matrix and one weight matrix (or filter matrix) from memory and then
stores the result activation matrix back to memory.

When the two GEMM/Conv operations are fused together, the mainloops of the two
GEMMs/Convs run back to back in a single kernel. The output accumulator of the
1st GEMM/Conv is kept in the register file and reused as the activation
input of the 2nd GEMM/Conv. This saves a round trip to memory for the activation
matrix.

This example computes the following:
- 1st GEMM/Conv: D0 = relu(alpha0 .\* A0 \*\* B0)
- 2nd GEMM/Conv: D1 = relu(alpha1 .\* D0 \*\* B1 + beta1 .\* C1)

In the above equations, the operator \*\* denotes either a matrix multiplication or a convolution, and .\* denotes scaling by a scalar.
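
For concreteness, the following is a minimal host-side sketch of the math the fused kernel produces for the GEMM case, assuming plain row-major matrices stored in `std::vector<float>`. The function names (`gemm_ref`, `b2b_gemm_ref`) and sizes are illustrative only and are not part of this example's API; the actual kernels compute the same result without materializing D0 in global memory.

```
// Reference sketch of the two fused stages (GEMM case), not the CUTLASS implementation.
#include <algorithm>
#include <vector>

// D = relu(alpha * A * B + beta * C), with A: MxK, B: KxN, C/D: MxN, row-major.
static void gemm_ref(int M, int N, int K, float alpha,
                     const std::vector<float>& A,
                     const std::vector<float>& B,
                     float beta,
                     const std::vector<float>& C,
                     std::vector<float>& D) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        acc += A[m * K + k] * B[k * N + n];
      }
      float value = alpha * acc + beta * C[m * N + n];
      D[m * N + n] = std::max(value, 0.f);  // relu
    }
  }
}

// D0 = relu(alpha0 * A0 * B0); D1 = relu(alpha1 * D0 * B1 + beta1 * C1).
// The fused kernel keeps D0 in registers (or shared memory) instead of
// writing it back to global memory between the two stages.
void b2b_gemm_ref(int M, int K0, int N0, int N1,
                  float alpha0, float alpha1, float beta1,
                  const std::vector<float>& A0, const std::vector<float>& B0,
                  const std::vector<float>& B1, const std::vector<float>& C1,
                  std::vector<float>& D1) {
  std::vector<float> zero(M * N0, 0.f);  // beta0 = 0 for the 1st GEMM
  std::vector<float> D0(M * N0);
  gemm_ref(M, N0, K0, alpha0, A0, B0, /*beta=*/0.f, zero, D0);
  gemm_ref(M, N1, N0, alpha1, D0, B1, beta1, C1, D1);
}
```
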
# Implementation Details

In order to run two GEMMs/Convs in a single kernel, the example requires that the same number of
threadblocks is used across the two GEMMs/Convs. This also ensures that the two GEMMs/Convs share
the same threadblock tile M.

In order to reuse the output accumulator (stored in the register file) of the 1st GEMM as the
input activation of the 2nd GEMM, the example enforces the following two constraints:

- thread_block_tile_N = problem_N

<p align="center"><img src=/media/images/13_example_block_resident_fusion.png></p>

This constraint ensures that each threadblock loads the entire weight/filter matrix in
addition to its own input activation tile. Therefore, the input activation tile of the
2nd GEMM/Conv depends only on the output activation tile of the 1st GEMM/Conv, and the
operation can be fully block-resident.

- warp_tile_N = thread_block_tile_N

<p align="center"><img src=/media/images/13_example_rf_resident_fusion.png></p>

This constraint ensures that each warp loads the entire weight/filter kBlock in
addition to its own input activation tile. Therefore, the input activation warp tile of the
2nd GEMM/Conv depends only on the output warp accumulator of the 1st GEMM/Conv in the
register file, and the operation can be fully register-file-resident.

The second constraint can be relaxed if the output accumulator of the 1st GEMM/Conv is
staged in shared memory and then used as the input of the 2nd GEMM/Conv. In that case, each
warp tile's input can be loaded from shared memory rather than kept RF-resident, so a warp
does not need to hold the entire input matrix of the 2nd GEMM in its register file. This is
illustrated in the diagram below.

<p align="center"><img src=/media/images/13_example_shmem_resident_fusion.png></p>

When applying the above constraints to convolutions, the 2nd convolution must not have halos,
so that the data used by each threadblock does not depend on any other threadblock. In practice
this means the 2nd convolution uses a 1x1 filter without padding.

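The run harnesses in this example verify these requirements at runtime via `can_implement` before launching the fused kernel. The helper below is a hypothetical, standalone sketch of the GEMM-case checks that the example prints when a problem is rejected; `b2b_gemm_fusion_supported`, `ThreadblockTileN0`, `ThreadblockTileN1`, and the sample sizes are illustrative names and values, not part of the CUTLASS API.

```
// Hypothetical sketch of the runtime constraints for back-to-back GEMM fusion:
// same M across the two GEMMs, N of GEMM0 equals K of GEMM1, and each
// threadblock tile N equals the corresponding problem N.
#include <iostream>

template <int ThreadblockTileN0, int ThreadblockTileN1>
bool b2b_gemm_fusion_supported(int m0, int n0, int k0,
                               int m1, int n1, int k1) {
  bool ok = true;
  ok &= (m0 == m1);                 // both GEMMs share the same M
  ok &= (n0 == k1);                 // output of GEMM0 feeds GEMM1 directly
  ok &= (ThreadblockTileN0 == n0);  // thread_block_tile_N = problem_N (GEMM0)
  ok &= (ThreadblockTileN1 == n1);  // thread_block_tile_N = problem_N (GEMM1)
  if (!ok) {
    std::cerr << "Problem sizes not supported for back-to-back fusion\n";
  }
  return ok;
}

// Example (illustrative sizes): a 128x64x576 GEMM feeding a 128x128x64 GEMM
// with threadblock tile N = 64 and N = 128 satisfies all four constraints.
// bool ok = b2b_gemm_fusion_supported<64, 128>(128, 64, 576, 128, 128, 64);
```
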
# Copyright

Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause

```
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

702
examples/13_two_tensor_op_fusion/b2b_conv2d_run.h
Normal file
@ -0,0 +1,702 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
||||
#include "cutlass/reduction/device/reduce_split_k.h"
|
||||
#include "cutlass/reduction/thread/reduction_operators.h"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
|
||||
#include "cutlass/util/reference/host/convolution.h"
|
||||
#include "cutlass/util/reference/device/convolution.h"
|
||||
#include "cutlass/util/reference/device/tensor_relu.h"
|
||||
|
||||
#include "cutlass/core_io.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
#define CHECK_GT(val1, val2) \
|
||||
if((val1) <= (val2)) \
|
||||
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
|
||||
#define CHECK_TRUE(val) \
|
||||
if(!(val)) \
|
||||
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
|
||||
|
||||
|
||||
template <typename Conv2d0_, typename Conv2d1_>
|
||||
class B2bNonFusedConv2dRun {
|
||||
public:
|
||||
|
||||
using Conv2d0 = Conv2d0_;
|
||||
using Conv2d1 = Conv2d1_;
|
||||
using ElementAccumulator = typename Conv2d0::ElementAccumulator;
|
||||
using ElementCompute = typename Conv2d0::ElementCompute;
|
||||
|
||||
static cutlass::conv::Operator const kConvolutionalOperator = Conv2d0::kConvolutionalOperator;
|
||||
static_assert(kConvolutionalOperator == Conv2d1::kConvolutionalOperator,
|
||||
"Fused convolution operators must be the same");
|
||||
|
||||
public:
|
||||
|
||||
/// Initialization
|
||||
cutlass::Distribution::Kind init_A;
|
||||
cutlass::Distribution::Kind init_B;
|
||||
cutlass::Distribution::Kind init_C;
|
||||
cutlass::Distribution::Kind init_Bias;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::HostTensor<typename Conv2d0::ElementA, typename Conv2d0::LayoutA> tensor_A0;
|
||||
cutlass::HostTensor<typename Conv2d0::ElementB, typename Conv2d0::LayoutB> tensor_B0;
|
||||
cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_C0;
|
||||
cutlass::HostTensor<typename Conv2d0::ElementCompute, typename Conv2d0::LayoutC> tensor_Bias0;
|
||||
cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_D0_computed;
|
||||
cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_D0_reference;
|
||||
|
||||
cutlass::HostTensor<typename Conv2d1::ElementB, typename Conv2d1::LayoutB> tensor_B1;
|
||||
cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_C1;
|
||||
cutlass::HostTensor<typename Conv2d1::ElementCompute, typename Conv2d0::LayoutC> tensor_Bias1;
|
||||
cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_D1_computed;
|
||||
cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_D1_reference;
|
||||
|
||||
|
||||
public:
|
||||
|
||||
B2bNonFusedConv2dRun(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
):
|
||||
init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) {
|
||||
|
||||
}
|
||||
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
void initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
int scope;
|
||||
int bits = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits <= 16) {
|
||||
scope = 2;
|
||||
}
|
||||
else {
|
||||
scope = 8;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope, -scope, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view, Element(0));
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllOnes) {
|
||||
cutlass::reference::host::TensorFill(view, Element(1));
|
||||
}
|
||||
else {
|
||||
}
|
||||
}
|
||||
|
||||
void initialize(
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_0,
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_1,
|
||||
uint64_t seed = 2019) {
|
||||
|
||||
tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_Bias0.resize({1, 1, 1, problem_size_0.K});
|
||||
tensor_D0_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_Bias1.resize({1, 1, 1, problem_size_1.K});
|
||||
tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
|
||||
initialize_tensor(tensor_A0.host_view(), init_A, seed);
|
||||
initialize_tensor(tensor_B0.host_view(), init_B, seed * 17);
|
||||
initialize_tensor(tensor_C0.host_view(), init_C, seed * 39);
|
||||
initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83);
|
||||
initialize_tensor(tensor_B1.host_view(), init_B, seed * 18);
|
||||
initialize_tensor(tensor_C1.host_view(), init_C, seed * 40);
|
||||
initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84);
|
||||
|
||||
tensor_A0.sync_device();
|
||||
tensor_B0.sync_device();
|
||||
tensor_C0.sync_device();
|
||||
tensor_Bias0.sync_device();
|
||||
tensor_D0_computed.sync_device();
|
||||
tensor_D0_reference.sync_device();
|
||||
tensor_B1.sync_device();
|
||||
tensor_C1.sync_device();
|
||||
tensor_Bias1.sync_device();
|
||||
tensor_D1_computed.sync_device();
|
||||
tensor_D1_reference.sync_device();
|
||||
}
|
||||
|
||||
/// Executes one test
|
||||
bool run(
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_0,
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_1,
|
||||
cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
initialize(problem_size_0, problem_size_1);
|
||||
|
||||
// configure the operator
|
||||
Conv2d0 conv2d_op_0;
|
||||
Conv2d1 conv2d_op_1;
|
||||
|
||||
typename Conv2d0::Arguments conv2d_args_0(
|
||||
problem_size_0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
{tensor_Bias0.device_data(), typename Conv2d0::LayoutC::Stride(0)},
|
||||
tensor_D0_computed.device_ref(),
|
||||
{alpha0, beta0},
|
||||
split_k_mode
|
||||
);
|
||||
typename Conv2d1::Arguments conv2d_args_1(
|
||||
problem_size_1,
|
||||
tensor_D0_computed.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
{tensor_Bias1.device_data(), typename Conv2d1::LayoutC::Stride(0)},
|
||||
tensor_D1_computed.device_ref(),
|
||||
{alpha1, beta1},
|
||||
split_k_mode
|
||||
);
|
||||
|
||||
|
||||
cutlass::Status status = conv2d_op_0.initialize(conv2d_args_0);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
status = conv2d_op_1.initialize(conv2d_args_1);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
for(int i = 0; i < warm_ups; i++) {
|
||||
status = conv2d_op_0();
|
||||
CUTLASS_CHECK(status);
|
||||
status = conv2d_op_1();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
//
|
||||
// Run Conv2d
|
||||
//
|
||||
cudaEvent_t start, stop1, stop2;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventCreate(&stop1);
|
||||
cudaEventCreate(&stop2);
|
||||
|
||||
cudaEventRecord(start);
|
||||
|
||||
|
||||
for(int i = 0; i < runs; i++) {
|
||||
// run conv2d operator
|
||||
status = conv2d_op_0();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
cudaEventRecord(stop1);
|
||||
|
||||
for(int i = 0; i < runs; i++) {
|
||||
// run conv2d operator
|
||||
status = conv2d_op_1();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
cudaEventRecord(stop2);
|
||||
cudaDeviceSynchronize();
|
||||
float conv2d0Time, conv2d1Time, totalTime;
|
||||
cudaEventElapsedTime(&conv2d0Time, start, stop1);
|
||||
cudaEventElapsedTime(&conv2d1Time, stop1, stop2);
|
||||
cudaEventElapsedTime(&totalTime, start, stop2);
|
||||
std::cout << "conv2d 0 time " << conv2d0Time / (float)runs << " ms\n";
|
||||
std::cout << "conv2d 1 time " << conv2d1Time / (float)runs << " ms\n";
|
||||
std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n";
|
||||
|
||||
tensor_D0_computed.sync_host();
|
||||
tensor_D1_computed.sync_host();
|
||||
|
||||
bool passed = false;
|
||||
|
||||
cutlass::reference::device::Conv2d<
|
||||
typename Conv2d0::ElementA,
|
||||
typename Conv2d0::LayoutA,
|
||||
typename Conv2d0::ElementB,
|
||||
typename Conv2d0::LayoutB,
|
||||
typename Conv2d0::ElementC,
|
||||
typename Conv2d0::LayoutC,
|
||||
ElementCompute,
|
||||
ElementAccumulator
|
||||
>(
|
||||
kConvolutionalOperator,
|
||||
problem_size_0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
{tensor_Bias0.device_data(), typename Conv2d0::LayoutC::Stride(0)},
|
||||
tensor_D0_reference.device_ref(),
|
||||
alpha0,
|
||||
beta0);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view());
|
||||
}
|
||||
|
||||
cutlass::reference::device::Conv2d<
|
||||
typename Conv2d1::ElementA,
|
||||
typename Conv2d1::LayoutA,
|
||||
typename Conv2d1::ElementB,
|
||||
typename Conv2d1::LayoutB,
|
||||
typename Conv2d1::ElementC,
|
||||
typename Conv2d1::LayoutC,
|
||||
ElementCompute,
|
||||
ElementAccumulator
|
||||
>(
|
||||
kConvolutionalOperator,
|
||||
problem_size_1,
|
||||
tensor_D0_reference.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
{tensor_Bias1.device_data(), typename Conv2d1::LayoutC::Stride(0)},
|
||||
tensor_D1_reference.device_ref(),
|
||||
alpha1,
|
||||
beta1);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view());
|
||||
}
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
CHECK_TRUE(result == cudaSuccess);
|
||||
|
||||
// sync host (copy device data to host) for dumping error output in case of mismatches
|
||||
tensor_D0_reference.sync_host();
|
||||
tensor_D1_reference.sync_host();
|
||||
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_computed.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0);
|
||||
|
||||
passed = cutlass::reference::host::TensorEquals(
|
||||
tensor_D1_computed.host_view(),
|
||||
tensor_D1_reference.host_view());
|
||||
|
||||
CHECK_TRUE(passed);
|
||||
|
||||
if (!passed) {
|
||||
std::stringstream fname;
|
||||
|
||||
fname << "error_B2bImplicitGemm_device_nonfused.txt";
|
||||
std::cerr << "Dumping results in " << fname.str() << "\n";
|
||||
|
||||
std::ofstream results(fname.str());
|
||||
|
||||
results << problem_size_0 << std::endl;
|
||||
results << problem_size_1 << std::endl;
|
||||
|
||||
results
|
||||
<< "\nA0:\n" << tensor_A0.host_view() << "\n"
|
||||
<< "\nB0:\n" << tensor_B0.host_view() << "\n"
|
||||
<< "\nC0:\n" << tensor_C0.host_view() << "\n"
|
||||
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
|
||||
<< "\nD0 reference:\n" << tensor_D0_reference.host_view() << "\n"
|
||||
<< "\nD0 computed:\n" << tensor_D0_computed.host_view() << "\n"
|
||||
<< "\nB1:\n" << tensor_B1.host_view() << "\n"
|
||||
<< "\nC1:\n" << tensor_C1.host_view() << "\n"
|
||||
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
|
||||
<< "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n"
|
||||
<< "\nD1 computed:\n" << tensor_D1_computed.host_view();
|
||||
|
||||
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <typename B2bConv2d_>
|
||||
class B2bFusedConv2dRun {
|
||||
public:
|
||||
|
||||
using B2bConv2d = B2bConv2d_;
|
||||
using ElementAccumulator = typename B2bConv2d::ElementAccumulator;
|
||||
using ElementCompute = typename B2bConv2d::ElementCompute;
|
||||
|
||||
static cutlass::conv::Operator const kConvolutionalOperator = B2bConv2d::kConvolutionalOperator;
|
||||
|
||||
public:
|
||||
|
||||
/// Initialization
|
||||
cutlass::Distribution::Kind init_A;
|
||||
cutlass::Distribution::Kind init_B;
|
||||
cutlass::Distribution::Kind init_C;
|
||||
cutlass::Distribution::Kind init_Scale;
|
||||
cutlass::Distribution::Kind init_Bias;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementA, typename B2bConv2d::LayoutA> tensor_A0;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B0;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C0;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Scale0;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Bias0;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D0_reference;
|
||||
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B1;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C1;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementCompute, typename B2bConv2d::LayoutC> tensor_Bias1;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D1_computed;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D1_reference;
|
||||
|
||||
|
||||
public:
|
||||
|
||||
B2bFusedConv2dRun(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
):
|
||||
init_A(init_A_), init_B(init_B_), init_C(init_C_),
|
||||
init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) {
|
||||
|
||||
}
|
||||
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
void initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
int scope;
|
||||
int bits = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits <= 16) {
|
||||
scope = 2;
|
||||
}
|
||||
else {
|
||||
scope = 8;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope, -scope, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view, Element(0));
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllOnes) {
|
||||
cutlass::reference::host::TensorFill(view, Element(1));
|
||||
}
|
||||
else {
|
||||
}
|
||||
}
|
||||
|
||||
void initialize(
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_0,
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_1,
|
||||
ElementCompute alpha0,
|
||||
ElementCompute alpha1,
|
||||
uint64_t seed = 2019) {
|
||||
|
||||
tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
||||
if(alpha0 == ElementCompute(0)) //per-channel scale
|
||||
tensor_Scale0.resize({1, problem_size_0.K});
|
||||
tensor_Bias0.resize({1, problem_size_0.K});
|
||||
tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_Bias1.resize({1, 1, 1, problem_size_1.K});
|
||||
tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
|
||||
initialize_tensor(tensor_A0.host_view(), init_A, seed);
|
||||
initialize_tensor(tensor_B0.host_view(), init_B, seed * 17);
|
||||
initialize_tensor(tensor_C0.host_view(), init_C, seed * 39);
|
||||
if(alpha0 == ElementCompute(0)) //per-channel scale
|
||||
initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed * 61);
|
||||
initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83);
|
||||
initialize_tensor(tensor_B1.host_view(), init_B, seed * 18);
|
||||
initialize_tensor(tensor_C1.host_view(), init_C, seed * 40);
|
||||
initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84);
|
||||
|
||||
tensor_A0.sync_device();
|
||||
tensor_B0.sync_device();
|
||||
tensor_C0.sync_device();
|
||||
if(alpha0 == ElementCompute(0)) //per-channel scale
|
||||
tensor_Scale0.sync_device();
|
||||
tensor_Bias0.sync_device();
|
||||
tensor_D0_reference.sync_device();
|
||||
tensor_B1.sync_device();
|
||||
tensor_C1.sync_device();
|
||||
tensor_Bias1.sync_device();
|
||||
tensor_D1_computed.sync_device();
|
||||
tensor_D1_reference.sync_device();
|
||||
}
|
||||
|
||||
/// Executes one test
|
||||
bool run(
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_0,
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_1,
|
||||
cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
initialize(problem_size_0, problem_size_1, alpha0, alpha1);
|
||||
|
||||
// configure the operator
|
||||
B2bConv2d b2b_conv2d_op;
|
||||
|
||||
typename B2bConv2d::Arguments b2b_conv2d_args(
|
||||
problem_size_0,
|
||||
problem_size_1,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
tensor_C0.device_ref(),
|
||||
tensor_Scale0.device_ref(),
|
||||
tensor_Bias0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
{tensor_Bias1.device_data(), typename B2bConv2d::LayoutC::Stride(0)},
|
||||
tensor_D1_computed.device_ref(),
|
||||
{alpha0, beta0},
|
||||
{alpha1, beta1},
|
||||
split_k_mode
|
||||
);
|
||||
|
||||
cutlass::Status status = b2b_conv2d_op.can_implement(b2b_conv2d_args);
|
||||
|
||||
if(status != cutlass::Status::kSuccess) {
|
||||
std::cout << "Problem sizes not supported.\n"
|
||||
<< "Requirments:\n"
|
||||
<< " problem_size_0.N*P*Q = problem_size_1.N*P*Q\n"
|
||||
<< " problem_size_0.K = problem_size_1.C\n"
|
||||
<< " problem_size_1.R = problem_size_1.S = 1\n"
|
||||
<< " ThreadblockShape0::kN = problem_size_0.K\n"
|
||||
<< " ThreadblockShape1::kN = problem_size_1.K" << std::endl;
|
||||
}
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
status = b2b_conv2d_op.initialize(b2b_conv2d_args);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
for(int i = 0; i < warm_ups; i++) {
|
||||
status = b2b_conv2d_op();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
//
|
||||
// Run the Conv2d
|
||||
//
|
||||
|
||||
cudaEvent_t start, stop;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventCreate(&stop);
|
||||
|
||||
cudaEventRecord(start);
|
||||
|
||||
for(int i = 0; i < runs; i++) {
|
||||
|
||||
// run conv2d operator
|
||||
status = b2b_conv2d_op();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
cudaEventRecord(stop);
|
||||
cudaDeviceSynchronize();
|
||||
float conv2dTime;
|
||||
cudaEventElapsedTime(&conv2dTime, start, stop);
|
||||
std::cout << "Fusion time " << conv2dTime / (float)runs << " ms\n";
|
||||
|
||||
tensor_D1_computed.sync_host();
|
||||
|
||||
bool passed = false;
|
||||
|
||||
cutlass::reference::device::Conv2d<
|
||||
typename B2bConv2d::ElementA,
|
||||
typename B2bConv2d::LayoutA,
|
||||
typename B2bConv2d::ElementB,
|
||||
typename B2bConv2d::LayoutB,
|
||||
typename B2bConv2d::ElementC,
|
||||
typename B2bConv2d::LayoutC,
|
||||
ElementCompute,
|
||||
ElementAccumulator
|
||||
>(
|
||||
kConvolutionalOperator,
|
||||
problem_size_0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
tensor_C0.device_ref(),
|
||||
tensor_D0_reference.device_ref(),
|
||||
alpha0,
|
||||
beta0,
|
||||
nullptr, // stream
|
||||
tensor_Scale0.device_ref(),
|
||||
tensor_Bias0.device_ref());
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view());
|
||||
}
|
||||
|
||||
cutlass::reference::device::Conv2d<
|
||||
typename B2bConv2d::ElementA,
|
||||
typename B2bConv2d::LayoutA,
|
||||
typename B2bConv2d::ElementB,
|
||||
typename B2bConv2d::LayoutB,
|
||||
typename B2bConv2d::ElementC,
|
||||
typename B2bConv2d::LayoutC,
|
||||
ElementCompute,
|
||||
ElementAccumulator
|
||||
>(
|
||||
kConvolutionalOperator,
|
||||
problem_size_1,
|
||||
tensor_D0_reference.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
{tensor_Bias1.device_data(), typename B2bConv2d::LayoutC::Stride(0)},
|
||||
tensor_D1_reference.device_ref(),
|
||||
alpha1,
|
||||
beta1);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view());
|
||||
}
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
CHECK_TRUE(result == cudaSuccess);
|
||||
|
||||
// sync host (copy device data to host) for dumping error output in case of mismatches
|
||||
tensor_D0_reference.sync_host();
|
||||
tensor_D1_reference.sync_host();
|
||||
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0);
|
||||
|
||||
passed = cutlass::reference::host::TensorEquals(
|
||||
tensor_D1_computed.host_view(),
|
||||
tensor_D1_reference.host_view());
|
||||
|
||||
CHECK_TRUE(passed);
|
||||
|
||||
if (!passed) {
|
||||
std::stringstream fname;
|
||||
|
||||
fname << "error_B2bImplicitGemm_device_fused.txt";
|
||||
std::cerr << "Dumping results in " << fname.str() << "\n";
|
||||
|
||||
std::ofstream results(fname.str());
|
||||
|
||||
results << problem_size_0 << std::endl;
|
||||
results << problem_size_1 << std::endl;
|
||||
|
||||
results
|
||||
<< "\nA0:\n" << tensor_A0.host_view() << "\n"
|
||||
<< "\nB0:\n" << tensor_B0.host_view() << "\n"
|
||||
<< "\nC0:\n" << tensor_C0.host_view() << "\n"
|
||||
<< "\nScale0:\n" << tensor_Scale0.host_view() << "\n"
|
||||
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
|
||||
<< "\nB1:\n" << tensor_B1.host_view() << "\n"
|
||||
<< "\nC1:\n" << tensor_C1.host_view() << "\n"
|
||||
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
|
||||
<< "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n"
|
||||
<< "\nD1 computed:\n" << tensor_D1_computed.host_view();
|
||||
|
||||
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -121,7 +127,9 @@ struct B2bNonFusedGemmRun
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
bool relu = true) {
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
@ -222,6 +230,14 @@ struct B2bNonFusedGemmRun
|
||||
status = gemm_op_1.initialize(arguments_1);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
for(int i = 0; i < warm_ups; i++) {
|
||||
status = gemm_op_0();
|
||||
CUTLASS_CHECK(status);
|
||||
status = gemm_op_1();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
//
|
||||
// Run the GEMM
|
||||
//
|
||||
@ -233,13 +249,13 @@ struct B2bNonFusedGemmRun
|
||||
|
||||
cudaEventRecord(start);
|
||||
|
||||
for(int i = 0; i < 100; i++) {
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = gemm_op_0();
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
cudaEventRecord(stop1);
|
||||
for(int i = 0; i < 100; i++) {
|
||||
for(int i = 0; i < runs; i++) {
|
||||
|
||||
status = gemm_op_1();
|
||||
|
||||
@ -252,9 +268,9 @@ struct B2bNonFusedGemmRun
|
||||
cudaEventElapsedTime(&gemm0Time, start, stop1);
|
||||
cudaEventElapsedTime(&gemm1Time, stop1, stop2);
|
||||
cudaEventElapsedTime(&totalTime, start, stop2);
|
||||
std::cout << "gemm 0 time " << gemm0Time / 100.0 << " ms\n";
|
||||
std::cout << "gemm 1 time " << gemm1Time / 100.0 << " ms\n";
|
||||
std::cout << "total time " << totalTime / 100.0 << " ms\n";
|
||||
std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
|
||||
std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
|
||||
std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n";
|
||||
|
||||
tensor_D0.sync_host();
|
||||
tensor_D1.sync_host();
|
||||
@ -415,7 +431,9 @@ struct B2bFusedGemmRun
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
bool relu = true) {
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
@ -433,10 +451,6 @@ struct B2bFusedGemmRun
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
|
||||
|
||||
// cutlass::HostTensor<
|
||||
// typename B2bGemm::ElementC,
|
||||
// typename B2bGemm::LayoutC> tensor_D0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
|
||||
@ -499,10 +513,27 @@ struct B2bFusedGemmRun
|
||||
|
||||
B2bGemm b2b_gemm_op;
|
||||
|
||||
cutlass::Status status = b2b_gemm_op.initialize(arguments);
|
||||
cutlass::Status status = b2b_gemm_op.can_implement(arguments);
|
||||
|
||||
if(status != cutlass::Status::kSuccess) {
|
||||
std::cout << "Problem sizes not supported.\n"
|
||||
<< "Requirments:\n"
|
||||
<< " problem_size_0.M = problem_size_1.M\n"
|
||||
<< " problem_size_0.N = problem_size_1.K\n"
|
||||
<< " ThreadblockShape0::kN = problem_size_0.N\n"
|
||||
<< " ThreadblockShape1::kN = problem_size_1.N" << std::endl;
|
||||
}
|
||||
|
||||
|
||||
status = b2b_gemm_op.initialize(arguments);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
for(int i = 0; i < warm_ups; i++) {
|
||||
status = b2b_gemm_op();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
//
|
||||
// Run the GEMM
|
||||
//
|
||||
@ -513,7 +544,7 @@ struct B2bFusedGemmRun
|
||||
|
||||
cudaEventRecord(start);
|
||||
|
||||
for(int i = 0; i < 100; i++) {
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = b2b_gemm_op();
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
@ -523,9 +554,8 @@ struct B2bFusedGemmRun
|
||||
cudaDeviceSynchronize();
|
||||
float gemmTime;
|
||||
cudaEventElapsedTime(&gemmTime, start, stop);
|
||||
std::cout << "time " << gemmTime / 100.0 << " ms\n";
|
||||
std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
|
||||
|
||||
//tensor_D0.sync_host();
|
||||
tensor_D1.sync_host();
|
||||
|
||||
//
|
||||
@ -593,7 +623,6 @@ struct B2bFusedGemmRun
|
||||
<< "A0 =\n" << tensor_A0.host_view()
|
||||
<< "\nB0 =\n" << tensor_B0.host_view()
|
||||
<< "\nC0 =\n" << tensor_C0.host_view()
|
||||
// << "\nD0 =\n" << tensor_D0.host_view()
|
||||
<< "\nB1 =\n" << tensor_B1.host_view()
|
||||
<< "\nC1 =\n" << tensor_C1.host_view()
|
||||
<< "\n\nReference =\n" << reference_D1.host_view()
|
||||
730
examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h
Normal file
@ -0,0 +1,730 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
||||
#include "cutlass/reduction/device/reduce_split_k.h"
|
||||
#include "cutlass/reduction/thread/reduction_operators.h"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/host_reorder.h"
|
||||
|
||||
#include "cutlass/util/reference/host/convolution.h"
|
||||
#include "cutlass/util/reference/device/convolution.h"
|
||||
#include "cutlass/util/reference/device/tensor_relu.h"
|
||||
|
||||
#include "cutlass/core_io.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
#define CHECK_GT(val1, val2) \
|
||||
if((val1) <= (val2)) \
|
||||
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
|
||||
#define CHECK_TRUE(val) \
|
||||
if(!(val)) \
|
||||
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
|
||||
|
||||
|
||||
template <typename Conv2d0_, typename Conv2d1_, int InterleavedK>
|
||||
class B2bInterleavedNonFusedConv2dRun {
|
||||
public:
|
||||
|
||||
using Conv2d0 = Conv2d0_;
|
||||
using Conv2d1 = Conv2d1_;
|
||||
using ElementAccumulator = typename Conv2d0::ElementAccumulator;
|
||||
using ElementCompute = typename Conv2d0::ElementCompute;
|
||||
|
||||
static cutlass::conv::Operator const kConvolutionalOperator = Conv2d0::kConvolutionalOperator;
|
||||
static_assert(kConvolutionalOperator == Conv2d1::kConvolutionalOperator,
|
||||
"Fused convolution operators must be the same");
|
||||
|
||||
public:
|
||||
|
||||
/// Initialization
|
||||
cutlass::Distribution::Kind init_A;
|
||||
cutlass::Distribution::Kind init_B;
|
||||
cutlass::Distribution::Kind init_C;
|
||||
cutlass::Distribution::Kind init_Bias;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::HostTensor<typename Conv2d0::ElementA, typename Conv2d0::LayoutA> tensor_A0;
|
||||
cutlass::HostTensor<typename Conv2d0::ElementB, typename Conv2d0::LayoutB> tensor_B0;
|
||||
cutlass::HostTensor<typename Conv2d0::ElementB, typename Conv2d0::LayoutB> tensor_B0_reordered;
|
||||
cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_C0;
|
||||
cutlass::HostTensor<typename Conv2d0::ElementCompute, typename Conv2d0::LayoutC> tensor_Bias0;
|
||||
cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_D0_computed;
|
||||
cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_D0_reference;
|
||||
|
||||
cutlass::HostTensor<typename Conv2d1::ElementB, typename Conv2d1::LayoutB> tensor_B1;
|
||||
cutlass::HostTensor<typename Conv2d1::ElementB, typename Conv2d1::LayoutB> tensor_B1_reordered;
|
||||
cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_C1;
|
||||
cutlass::HostTensor<typename Conv2d1::ElementCompute, typename Conv2d0::LayoutC> tensor_Bias1;
|
||||
cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_D1_computed;
|
||||
cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_D1_reference;
|
||||
|
||||
|
||||
public:
|
||||
|
||||
B2bInterleavedNonFusedConv2dRun(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
):
|
||||
init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) {
|
||||
|
||||
}
|
||||
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
void initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
int scope;
|
||||
int bits = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits <= 16) {
|
||||
scope = 2;
|
||||
}
|
||||
else {
|
||||
scope = 8;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope, -scope, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view, Element(0));
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllOnes) {
|
||||
cutlass::reference::host::TensorFill(view, Element(1));
|
||||
}
|
||||
else {
|
||||
}
|
||||
}
|
||||
|
||||
void initialize(
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_0,
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_1, uint64_t seed = 2019) {
|
||||
|
||||
tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_B0_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_Bias0.resize({1, 1, 1, problem_size_0.K});
|
||||
tensor_D0_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_B1_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_Bias1.resize({1, 1, 1, problem_size_1.K});
|
||||
tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
|
||||
initialize_tensor(tensor_A0.host_view(), init_A, seed);
|
||||
initialize_tensor(tensor_B0.host_view(), init_B, seed * 17);
|
||||
initialize_tensor(tensor_C0.host_view(), init_C, seed * 39);
|
||||
initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83);
|
||||
initialize_tensor(tensor_B1.host_view(), init_B, seed * 18);
|
||||
initialize_tensor(tensor_C1.host_view(), init_C, seed * 40);
|
||||
|
||||
//Reorder B0 and B1
|
||||
cutlass::reorder_convK<InterleavedK, InterleavedK>(
|
||||
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_0));
|
||||
cutlass::reorder_convK<InterleavedK, InterleavedK>(
|
||||
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_1));
|
||||
|
||||
tensor_A0.sync_device();
|
||||
tensor_B0.sync_device();
|
||||
tensor_B0_reordered.sync_device();
|
||||
tensor_C0.sync_device();
|
||||
tensor_Bias0.sync_device();
|
||||
tensor_D0_computed.sync_device();
|
||||
tensor_D0_reference.sync_device();
|
||||
tensor_B1.sync_device();
|
||||
tensor_B1_reordered.sync_device();
|
||||
tensor_C1.sync_device();
|
||||
tensor_Bias1.sync_device();
|
||||
tensor_D1_computed.sync_device();
|
||||
tensor_D1_reference.sync_device();
|
||||
}
|
||||
|
||||
/// Executes one test
|
||||
bool run(
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_0,
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_1,
|
||||
cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
initialize(problem_size_0, problem_size_1);
|
||||
|
||||
// configure the operator
|
||||
Conv2d0 conv2d_op_0;
|
||||
Conv2d1 conv2d_op_1;
|
||||
|
||||
typename Conv2d0::Arguments conv2d_args_0(
|
||||
problem_size_0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0_reordered.device_ref(),
|
||||
tensor_C0.device_ref(),
|
||||
tensor_D0_computed.device_ref(),
|
||||
{alpha0, beta0},
|
||||
split_k_mode
|
||||
);
|
||||
typename Conv2d1::Arguments conv2d_args_1(
|
||||
problem_size_1,
|
||||
tensor_D0_computed.device_ref(),
|
||||
tensor_B1_reordered.device_ref(),
|
||||
tensor_C1.device_ref(),
|
||||
tensor_D1_computed.device_ref(),
|
||||
{alpha1, beta1},
|
||||
split_k_mode
|
||||
);
|
||||
|
||||
|
||||
cutlass::Status status = conv2d_op_0.initialize(conv2d_args_0);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
status = conv2d_op_1.initialize(conv2d_args_1);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
for(int i = 0; i < warm_ups; i++) {
|
||||
status = conv2d_op_0();
|
||||
CUTLASS_CHECK(status);
|
||||
status = conv2d_op_1();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
//
|
||||
// Run Conv2d
|
||||
//
|
||||
cudaEvent_t start, stop1, stop2;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventCreate(&stop1);
|
||||
cudaEventCreate(&stop2);
|
||||
|
||||
cudaEventRecord(start);
|
||||
|
||||
|
||||
for(int i = 0; i < runs; i++) {
|
||||
// run conv2d operator
|
||||
status = conv2d_op_0();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
cudaEventRecord(stop1);
|
||||
|
||||
for(int i = 0; i < runs; i++) {
|
||||
// run conv2d operator
|
||||
status = conv2d_op_1();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
cudaEventRecord(stop2);
|
||||
cudaDeviceSynchronize();
|
||||
float conv2d0Time, conv2d1Time, totalTime;
|
||||
cudaEventElapsedTime(&conv2d0Time, start, stop1);
|
||||
cudaEventElapsedTime(&conv2d1Time, stop1, stop2);
|
||||
cudaEventElapsedTime(&totalTime, start, stop2);
|
||||
std::cout << "conv2d 0 time " << conv2d0Time / (float)runs << " ms\n";
|
||||
std::cout << "conv2d 1 time " << conv2d1Time / (float)runs << " ms\n";
|
||||
std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n";
|
||||
|
||||
tensor_D0_computed.sync_host();
|
||||
tensor_D1_computed.sync_host();
|
||||
|
||||
bool passed = false;
|
||||
|
||||
cutlass::reference::device::Conv2d<
|
||||
typename Conv2d0::ElementA,
|
||||
typename Conv2d0::LayoutA,
|
||||
typename Conv2d0::ElementB,
|
||||
typename Conv2d0::LayoutB,
|
||||
typename Conv2d0::ElementC,
|
||||
typename Conv2d0::LayoutC,
|
||||
ElementCompute,
|
||||
ElementAccumulator,
|
||||
cutlass::NumericConverterClamp<typename Conv2d0::ElementC, ElementCompute>
|
||||
>(
|
||||
kConvolutionalOperator,
|
||||
problem_size_0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
tensor_C0.device_ref(),
|
||||
tensor_D0_reference.device_ref(),
|
||||
alpha0,
|
||||
beta0);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view());
|
||||
}
|
||||
|
||||
cutlass::reference::device::Conv2d<
|
||||
typename Conv2d1::ElementA,
|
||||
typename Conv2d1::LayoutA,
|
||||
typename Conv2d1::ElementB,
|
||||
typename Conv2d1::LayoutB,
|
||||
typename Conv2d1::ElementC,
|
||||
typename Conv2d1::LayoutC,
|
||||
ElementCompute,
|
||||
ElementAccumulator,
|
||||
cutlass::NumericConverterClamp<typename Conv2d1::ElementC, ElementCompute>
|
||||
>(
|
||||
kConvolutionalOperator,
|
||||
problem_size_1,
|
||||
tensor_D0_reference.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
tensor_C1.device_ref(),
|
||||
tensor_D1_reference.device_ref(),
|
||||
alpha1,
|
||||
beta1);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view());
|
||||
}
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
CHECK_TRUE(result == cudaSuccess);
|
||||
|
||||
// sync host (copy device data to host) for dumping error output in case of mismatches
|
||||
tensor_D0_reference.sync_host();
|
||||
tensor_D1_reference.sync_host();
|
||||
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_computed.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0);
|
||||
|
||||
passed = cutlass::reference::host::TensorEquals(
|
||||
tensor_D1_computed.host_view(),
|
||||
tensor_D1_reference.host_view());
|
||||
|
||||
CHECK_TRUE(passed);
|
||||
|
||||
if (!passed) {
|
||||
std::stringstream fname;
|
||||
|
||||
fname << "error_B2bImplicitGemm_device_interleaved_nonfused.txt";
|
||||
std::cerr << "Dumping results in " << fname.str() << "\n";
|
||||
|
||||
std::ofstream results(fname.str());
|
||||
|
||||
results << problem_size_0 << std::endl;
|
||||
results << problem_size_1 << std::endl;
|
||||
|
||||
results
|
||||
<< "\nA0:\n" << tensor_A0.host_view() << "\n"
|
||||
<< "\nB0:\n" << tensor_B0.host_view() << "\n"
|
||||
<< "\nB0_reordered:\n" << tensor_B0_reordered.host_view() << "\n"
|
||||
<< "\nC0:\n" << tensor_C0.host_view() << "\n"
|
||||
<< "\nD0 reference:\n" << tensor_D0_reference.host_view() << "\n"
|
||||
<< "\nD0 computed:\n" << tensor_D0_computed.host_view() << "\n"
|
||||
<< "\nB1:\n" << tensor_B1.host_view() << "\n"
|
||||
<< "\nB1_reordered:\n" << tensor_B1_reordered.host_view() << "\n"
|
||||
<< "\nC1:\n" << tensor_C1.host_view() << "\n"
|
||||
<< "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n"
|
||||
<< "\nD1 computed:\n" << tensor_D1_computed.host_view();
|
||||
|
||||
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <typename B2bConv2d_, int InterleavedK>
|
||||
class B2bInterleavedFusedConv2dRun {
|
||||
public:
|
||||
|
||||
using B2bConv2d = B2bConv2d_;
|
||||
using ElementAccumulator = typename B2bConv2d::ElementAccumulator;
|
||||
using ElementCompute = typename B2bConv2d::ElementCompute;
|
||||
|
||||
static cutlass::conv::Operator const kConvolutionalOperator = B2bConv2d::kConvolutionalOperator;
|
||||
|
||||
public:
|
||||
|
||||
/// Initialization
|
||||
cutlass::Distribution::Kind init_A;
|
||||
cutlass::Distribution::Kind init_B;
|
||||
cutlass::Distribution::Kind init_C;
|
||||
cutlass::Distribution::Kind init_Scale;
|
||||
cutlass::Distribution::Kind init_Bias;
|
||||
uint64_t seed;
|
||||
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementA, typename B2bConv2d::LayoutA> tensor_A0;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B0;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B0_reordered;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C0;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Scale0;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Bias0;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D0_reference;
|
||||
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B1;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B1_reordered;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C1;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementCompute, typename B2bConv2d::LayoutC> tensor_Bias1;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D1_computed;
|
||||
cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D1_reference;
|
||||
|
||||
|
||||
public:
|
||||
|
||||
B2bInterleavedFusedConv2dRun(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
):
|
||||
init_A(init_A_), init_B(init_B_), init_C(init_C_),
|
||||
init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) {
|
||||
|
||||
}
|
||||
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
void initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
int scope;
|
||||
int bits = cutlass::sizeof_bits<Element>::value;
|
||||
|
||||
if (bits <= 16) {
|
||||
scope = 2;
|
||||
}
|
||||
else {
|
||||
scope = 8;
|
||||
}
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, scope, -scope, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
|
||||
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllZeros) {
|
||||
cutlass::reference::host::TensorFill(view, Element(0));
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::AllOnes) {
|
||||
cutlass::reference::host::TensorFill(view, Element(1));
|
||||
}
|
||||
else {
|
||||
}
|
||||
}
|
||||
|
||||
void initialize(
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_0,
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_1,
|
||||
ElementCompute alpha0,
|
||||
ElementCompute alpha1,
|
||||
uint64_t seed = 2019) {
|
||||
|
||||
tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_B0_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
||||
if(alpha0 == ElementCompute(0)) //per-channel scale
|
||||
tensor_Scale0.resize({1, problem_size_0.K});
|
||||
tensor_Bias0.resize({1, problem_size_0.K});
|
||||
tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
|
||||
tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_B1_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_Bias1.resize({1, 1, 1, problem_size_1.K});
|
||||
tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
|
||||
|
||||
initialize_tensor(tensor_A0.host_view(), init_A, seed);
|
||||
initialize_tensor(tensor_B0.host_view(), init_B, seed * 17);
|
||||
initialize_tensor(tensor_C0.host_view(), init_C, seed * 39);
|
||||
if(alpha0 == ElementCompute(0)) //per-channel scale
|
||||
initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed * 61);
|
||||
initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83);
|
||||
initialize_tensor(tensor_B1.host_view(), init_B, seed * 18);
|
||||
initialize_tensor(tensor_C1.host_view(), init_C, seed * 40);
|
||||
initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84);
|
||||
|
||||
//Reorder B0 and B1
|
||||
cutlass::reorder_convK<16, InterleavedK>(
|
||||
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_0));
|
||||
cutlass::reorder_convK<InterleavedK, InterleavedK>(
|
||||
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_1));
|
||||
|
||||
tensor_A0.sync_device();
|
||||
tensor_B0.sync_device();
|
||||
tensor_B0_reordered.sync_device();
|
||||
tensor_C0.sync_device();
|
||||
if(alpha0 == ElementCompute(0)) //per-channel scale
|
||||
tensor_Scale0.sync_device();
|
||||
tensor_Bias0.sync_device();
|
||||
tensor_D0_reference.sync_device();
|
||||
tensor_B1.sync_device();
|
||||
tensor_B1_reordered.sync_device();
|
||||
tensor_C1.sync_device();
|
||||
tensor_Bias1.sync_device();
|
||||
tensor_D1_computed.sync_device();
|
||||
tensor_D1_reference.sync_device();
|
||||
}
|
||||
|
||||
/// Executes one test
|
||||
bool run(
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_0,
|
||||
cutlass::conv::Conv2dProblemSize const &problem_size_1,
|
||||
cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
initialize(problem_size_0, problem_size_1, alpha0, alpha1);
|
||||
|
||||
// configure the operator
|
||||
B2bConv2d b2b_conv2d_op;
|
||||
|
||||
typename B2bConv2d::Arguments b2b_conv2d_args(
|
||||
problem_size_0,
|
||||
problem_size_1,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0_reordered.device_ref(),
|
||||
tensor_C0.device_ref(),
|
||||
tensor_Scale0.device_ref(),
|
||||
tensor_Bias0.device_ref(),
|
||||
tensor_B1_reordered.device_ref(),
|
||||
tensor_C1.device_ref(),
|
||||
tensor_D1_computed.device_ref(),
|
||||
{alpha0, beta0},
|
||||
{alpha1, beta1},
|
||||
split_k_mode
|
||||
);
|
||||
|
||||
cutlass::Status status = b2b_conv2d_op.can_implement(b2b_conv2d_args);
|
||||
|
||||
if(status != cutlass::Status::kSuccess) {
std::cout << "Problem sizes not supported.\n"
<< "Requirements:\n"
<< " problem_size_0.N*P*Q = problem_size_1.N*P*Q\n"
<< " problem_size_0.K = problem_size_1.C\n"
<< " problem_size_1.R = problem_size_1.S = 1\n"
<< " ThreadblockShape0::kN = problem_size_0.K\n"
<< " ThreadblockShape1::kN = problem_size_1.K" << std::endl;
}
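// For reference, a problem-size pair that satisfies these checks (taken from the SM75 example
// added later in this change set, so treat the exact numbers as illustrative rather than required):
// conv0: input 32x56x56x64 (NHWC), filter 64x3x3x64 (KRSC), pad 1 -> output 32x56x56x64
// conv1: input 32x56x56x64, filter 128x1x1x64, pad 0 -> output 32x56x56x128
// Here problem_size_0.K == problem_size_1.C == 64, conv1 uses a 1x1 filter (R = S = 1), and the
// threadblock tiles are chosen with kN = 64 and kN = 128 to match the two K extents.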
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
status = b2b_conv2d_op.initialize(b2b_conv2d_args);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
for(int i = 0; i < warm_ups; i++) {
|
||||
status = b2b_conv2d_op();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
//
|
||||
// Run the Conv2d
|
||||
//
|
||||
|
||||
cudaEvent_t start, stop;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventCreate(&stop);
|
||||
|
||||
cudaEventRecord(start);
|
||||
|
||||
for(int i = 0; i < runs; i++) {
|
||||
|
||||
// run conv2d operator
|
||||
status = b2b_conv2d_op();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
cudaEventRecord(stop);
|
||||
cudaDeviceSynchronize();
|
||||
float conv2dTime;
|
||||
cudaEventElapsedTime(&conv2dTime, start, stop);
|
||||
std::cout << "Fusion time " << conv2dTime / (float)runs << " ms\n";
|
||||
|
||||
tensor_D1_computed.sync_host();
|
||||
|
||||
bool passed = false;
|
||||
|
||||
cutlass::reference::device::Conv2d<
|
||||
typename B2bConv2d::ElementA,
|
||||
typename B2bConv2d::LayoutA,
|
||||
typename B2bConv2d::ElementB,
|
||||
typename B2bConv2d::LayoutB,
|
||||
typename B2bConv2d::ElementC,
|
||||
typename B2bConv2d::LayoutC,
|
||||
ElementCompute,
|
||||
ElementAccumulator,
|
||||
cutlass::NumericConverterClamp<typename B2bConv2d::ElementC, ElementCompute>
|
||||
>(
|
||||
kConvolutionalOperator,
|
||||
problem_size_0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
tensor_C0.device_ref(),
|
||||
tensor_D0_reference.device_ref(),
|
||||
alpha0,
|
||||
beta0,
|
||||
nullptr, // stream
|
||||
tensor_Scale0.device_ref(),
|
||||
tensor_Bias0.device_ref());
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view());
|
||||
}
|
||||
|
||||
cutlass::reference::device::Conv2d<
|
||||
typename B2bConv2d::ElementA,
|
||||
typename B2bConv2d::LayoutA,
|
||||
typename B2bConv2d::ElementB,
|
||||
typename B2bConv2d::LayoutB,
|
||||
typename B2bConv2d::ElementC,
|
||||
typename B2bConv2d::LayoutC,
|
||||
ElementCompute,
|
||||
ElementAccumulator,
|
||||
cutlass::NumericConverterClamp<typename B2bConv2d::ElementC, ElementCompute>
|
||||
>(
|
||||
kConvolutionalOperator,
|
||||
problem_size_1,
|
||||
tensor_D0_reference.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
tensor_C1.device_ref(),
|
||||
tensor_D1_reference.device_ref(),
|
||||
alpha1,
|
||||
beta1);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view());
|
||||
}
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
CHECK_TRUE(result == cudaSuccess);
|
||||
|
||||
// sync host (copy device data to host) for dumping error output in case of mismatches
|
||||
tensor_D0_reference.sync_host();
|
||||
tensor_D1_reference.sync_host();
|
||||
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0);
|
||||
|
||||
passed = cutlass::reference::host::TensorEquals(
|
||||
tensor_D1_computed.host_view(),
|
||||
tensor_D1_reference.host_view());
|
||||
|
||||
CHECK_TRUE(passed);
|
||||
|
||||
if (!passed) {
|
||||
std::stringstream fname;
|
||||
|
||||
fname << "error_B2bImplicitGemm_device_interleaved_fused.txt";
|
||||
std::cerr << "Dumping results in " << fname.str() << "\n";
|
||||
|
||||
std::ofstream results(fname.str());
|
||||
|
||||
results << problem_size_0 << std::endl;
|
||||
results << problem_size_1 << std::endl;
|
||||
|
||||
results
|
||||
<< "\nA0:\n" << tensor_A0.host_view() << "\n"
|
||||
<< "\nB0:\n" << tensor_B0.host_view() << "\n"
|
||||
<< "\nB0_reordered:\n" << tensor_B0_reordered.host_view() << "\n"
|
||||
<< "\nC0:\n" << tensor_C0.host_view() << "\n"
|
||||
<< "\nScale0:\n" << tensor_Scale0.host_view() << "\n"
|
||||
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
|
||||
<< "\nB1:\n" << tensor_B1.host_view() << "\n"
|
||||
<< "\nB1_reordered:\n" << tensor_B1_reordered.host_view() << "\n"
|
||||
<< "\nC1:\n" << tensor_C1.host_view() << "\n"
|
||||
<< "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n"
|
||||
<< "\nD1 computed:\n" << tensor_D1_computed.host_view();
|
||||
|
||||
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
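// A minimal usage sketch of the fused interleaved runner declared above. The concrete B2bConv2d
// device type, the interleave factor of 32, and the two problem-size objects are assumptions for
// illustration; only the template parameters and the run() signature come from this header.
//
//   B2bInterleavedFusedConv2dRun<B2bConv2d, 32> fused_conv2d;
//   bool passed = fused_conv2d.run(
//       problem_size_0, problem_size_1,
//       cutlass::conv::SplitKMode::kSerial,
//       ElementCompute(1), ElementCompute(0),  // alpha0, beta0 (alpha0 == 0 selects per-channel scale)
//       ElementCompute(1), ElementCompute(0),  // alpha1, beta1
//       true,                                  // apply ReLU to the reference outputs as well
//       1, 100);                               // warm-up and timed iterations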
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -38,6 +44,8 @@
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/host_reorder.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_relu.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
#define CHECK_GT(val1, val2) \
|
||||
@ -115,7 +123,9 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
bool relu = true) {
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
@ -232,6 +242,14 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
status = gemm_op_1.initialize(arguments_1);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
for(int i = 0; i < warm_ups; i++) {
|
||||
status = gemm_op_0();
|
||||
CUTLASS_CHECK(status);
|
||||
status = gemm_op_1();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
//
|
||||
// Run the GEMM
|
||||
//
|
||||
@ -242,14 +260,14 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
|
||||
cudaEventRecord(start);
|
||||
|
||||
for(int i = 0; i < 100; i++) {
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = gemm_op_0();
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
cudaEventRecord(stop1);
|
||||
|
||||
for(int i = 0; i < 100; i++) {
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = gemm_op_1();
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
@ -261,13 +279,14 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
cudaEventElapsedTime(&gemm0Time, start, stop1);
|
||||
cudaEventElapsedTime(&gemm1Time, stop1, stop2);
|
||||
cudaEventElapsedTime(&totalTime, start, stop2);
|
||||
std::cout << "gemm 0 time " << gemm0Time / 100.0 << " ms\n";
|
||||
std::cout << "gemm 1 time " << gemm1Time / 100.0 << " ms\n";
|
||||
std::cout << "total time " << totalTime / 100.0 << " ms\n";
|
||||
std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
|
||||
std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
|
||||
std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n";
|
||||
|
||||
tensor_D0.sync_host();
|
||||
tensor_D1.sync_host();
|
||||
|
||||
bool passed = false;
|
||||
//
|
||||
// Verify
|
||||
//
|
||||
@ -302,7 +321,7 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
reference_gemm_1(
|
||||
problem_size_1,
|
||||
alpha1,
|
||||
tensor_D0.device_ref(),
|
||||
reference_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
beta1,
|
||||
tensor_C1.device_ref(),
|
||||
@ -322,7 +341,7 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
passed = cutlass::reference::host::TensorEquals(
|
||||
reference_D1.host_view(),
|
||||
tensor_D1.host_view());
|
||||
|
||||
@ -348,7 +367,6 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
<< "\n\nReference =\n" << reference_D1.host_view()
|
||||
<< "\nComputed =\n" << tensor_D1.host_view();
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
};
|
||||
@ -420,7 +438,9 @@ struct B2bInterleavedFusedGemmRun
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
bool relu = true) {
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
@ -442,10 +462,6 @@ struct B2bInterleavedFusedGemmRun
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
|
||||
|
||||
// cutlass::HostTensor<
|
||||
// typename B2bGemm::ElementC,
|
||||
// typename B2bGemm::LayoutC> tensor_D0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
|
||||
@ -478,7 +494,7 @@ struct B2bInterleavedFusedGemmRun
|
||||
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
|
||||
|
||||
//Reorder B0
|
||||
cutlass::reorder_column<B2bGemm::InstructionShape::kK>(
|
||||
cutlass::reorder_column<16>(
|
||||
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0);
|
||||
cutlass::reorder_column<InterleavedK_>(
|
||||
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1);
|
||||
@ -494,7 +510,6 @@ struct B2bInterleavedFusedGemmRun
|
||||
tensor_B0.sync_device();
|
||||
tensor_B0_reordered.sync_device();
|
||||
tensor_C0.sync_device();
|
||||
//tensor_D0.sync_device();
|
||||
tensor_B1.sync_device();
|
||||
tensor_B1_reordered.sync_device();
|
||||
tensor_C1.sync_device();
|
||||
@ -522,10 +537,26 @@ struct B2bInterleavedFusedGemmRun
|
||||
|
||||
B2bGemm b2b_gemm_op;
|
||||
|
||||
cutlass::Status status = b2b_gemm_op.initialize(arguments);
|
||||
cutlass::Status status = b2b_gemm_op.can_implement(arguments);
|
||||
|
||||
if(status != cutlass::Status::kSuccess) {
std::cout << "Problem sizes not supported.\n"
<< "Requirements:\n"
<< " problem_size_0.M = problem_size_1.M\n"
<< " problem_size_0.N = problem_size_1.K\n"
<< " ThreadblockShape0::kN = problem_size_0.N\n"
<< " ThreadblockShape1::kN = problem_size_1.N" << std::endl;
}
|
||||
|
||||
status = b2b_gemm_op.initialize(arguments);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
for(int i = 0; i < warm_ups; i++) {
|
||||
status = b2b_gemm_op();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
//
|
||||
// Run the GEMM
|
||||
//
|
||||
@ -536,7 +567,7 @@ struct B2bInterleavedFusedGemmRun
|
||||
|
||||
cudaEventRecord(start);
|
||||
|
||||
for(int i = 0; i < 100; i++) {
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = b2b_gemm_op();
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
@ -546,11 +577,11 @@ struct B2bInterleavedFusedGemmRun
|
||||
cudaDeviceSynchronize();
|
||||
float gemmTime;
|
||||
cudaEventElapsedTime(&gemmTime, start, stop);
|
||||
std::cout << "time " << gemmTime / 100.0 << " ms\n";
|
||||
std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
|
||||
|
||||
//tensor_D0.sync_host();
|
||||
tensor_D1.sync_host();
|
||||
|
||||
bool passed = false;
|
||||
//
|
||||
// Verify
|
||||
//
|
||||
@ -598,7 +629,7 @@ struct B2bInterleavedFusedGemmRun
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
passed = cutlass::reference::host::TensorEquals(
|
||||
reference_D1.host_view(),
|
||||
tensor_D1.host_view());
|
||||
|
||||
@ -617,14 +648,12 @@ struct B2bInterleavedFusedGemmRun
|
||||
<< "\nB0 =\n" << tensor_B0.host_view()
|
||||
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
|
||||
<< "\nC0 =\n" << tensor_C0.host_view()
|
||||
// << "\nD0 =\n" << tensor_D0.host_view()
|
||||
<< "\nB1 =\n" << tensor_B1.host_view()
|
||||
<< "\nB1_reordered =\n" << tensor_B1_reordered.host_view()
|
||||
<< "\nC1 =\n" << tensor_C1.host_view()
|
||||
<< "\n\nReference =\n" << reference_D1.host_view()
|
||||
<< "\nComputed =\n" << tensor_D1.host_view();
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -40,6 +46,7 @@
|
||||
|
||||
#include "kernel/b2b_gemm.h"
|
||||
#include "kernel/default_b2b_gemm.h"
|
||||
#include "kernel/default_b2b_gemm_smem_accumulator.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -102,6 +109,8 @@ template <
|
||||
int Stages =
|
||||
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||||
ElementC_, ElementAccumulator_>::kStages,
|
||||
/// Stage accumulator in shared memory
|
||||
bool SmemAccumulator = false,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA =
|
||||
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||||
@ -115,9 +124,7 @@ template <
|
||||
/// Operation performed by GEMM
|
||||
typename Operator_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
ElementAccumulator_>::Operator,
|
||||
/// Whether Beta is zero or not
|
||||
bool IsBetaZero = false>
|
||||
ElementAccumulator_>::Operator>
|
||||
class B2bGemm {
|
||||
public:
|
||||
|
||||
@ -148,7 +155,6 @@ class B2bGemm {
|
||||
static int const kAlignmentB = AlignmentB;
|
||||
static int const kAlignmentC = EpilogueOutputOp1::kCount;
|
||||
static bool const kSplitKSerial = SplitKSerial;
|
||||
static bool const kIsBetaZero = IsBetaZero;
|
||||
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
||||
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
||||
|
||||
@ -176,7 +182,7 @@ class B2bGemm {
|
||||
kStages,
|
||||
kSplitKSerial,
|
||||
Operator,
|
||||
kIsBetaZero
|
||||
SmemAccumulator
|
||||
>::B2bGemmKernel;
|
||||
|
||||
/// Argument structure
|
||||
@ -394,14 +400,6 @@ public:
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
Kernel<B2bGemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
cutlass::Kernel<B2bGemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
@ -422,7 +420,7 @@ public:
|
||||
void *workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) {
|
||||
|
||||
Status status = initialize(args, workspace);
|
||||
Status status = initialize(args, workspace, stream);
|
||||
|
||||
if (status == Status::kSuccess) {
|
||||
status = run(stream);
|
||||
@ -0,0 +1,300 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Template for the device-level back-to-back (B2B) fused Implicit GEMM Convolution
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
|
||||
#include "kernel/b2b_implicit_gemm_convolution.h"
|
||||
#include "kernel/default_b2b_conv2d_fprop.h"
|
||||
#include "kernel/default_b2b_conv2d_fprop_sm75.h"
|
||||
#include "kernel/default_b2b_conv2d_fprop_sm80.h"
|
||||
#include "kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h"
|
||||
#include "kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace device {
|
||||
|
||||
template<typename B2bImplicitGemmKernel_>
|
||||
class B2bImplicitGemmConvolution {
|
||||
public:
|
||||
|
||||
using B2bImplicitGemmKernel = B2bImplicitGemmKernel_;
|
||||
|
||||
using ElementA = typename B2bImplicitGemmKernel::ElementA;
|
||||
using LayoutA = typename B2bImplicitGemmKernel::LayoutA;
|
||||
using ElementB = typename B2bImplicitGemmKernel::ElementB;
|
||||
using LayoutB = typename B2bImplicitGemmKernel::LayoutB;
|
||||
using ElementC = typename B2bImplicitGemmKernel::ElementC;
|
||||
using LayoutC = typename B2bImplicitGemmKernel::LayoutC;
|
||||
using ElementAccumulator = typename B2bImplicitGemmKernel::ElementAccumulator;
|
||||
using ElementCompute = typename B2bImplicitGemmKernel::ElementCompute;
|
||||
using ElementScaleBias = typename B2bImplicitGemmKernel::ElementScaleBias;
|
||||
using LayoutScaleBias = typename B2bImplicitGemmKernel::LayoutScaleBias;
|
||||
using OperatorClass = typename B2bImplicitGemmKernel::OperatorClass;
|
||||
using ArchTag = typename B2bImplicitGemmKernel::ArchTag;
|
||||
using ThreadblockShape0 = typename B2bImplicitGemmKernel::ThreadblockShape0;
|
||||
using ThreadblockShape1 = typename B2bImplicitGemmKernel::ThreadblockShape1;
|
||||
using WarpShape0 = typename B2bImplicitGemmKernel::WarpShape0;
|
||||
using WarpShape1 = typename B2bImplicitGemmKernel::WarpShape1;
|
||||
using InstructionShape = typename B2bImplicitGemmKernel::InstructionShape;
|
||||
using ThreadblockSwizzle = typename B2bImplicitGemmKernel::ThreadblockSwizzle;
|
||||
using EpilogueOutputOp0 = typename B2bImplicitGemmKernel::EpilogueOutputOp0;
|
||||
using EpilogueOutputOp1 = typename B2bImplicitGemmKernel::EpilogueOutputOp1;
|
||||
static int const kStages = B2bImplicitGemmKernel::kStages;
|
||||
static int const kConvDim = B2bImplicitGemmKernel::kConvDim;
|
||||
using WarpMmaOperator0 = typename B2bImplicitGemmKernel::WarpMmaOperator0;
|
||||
using WarpMmaOperator1 = typename B2bImplicitGemmKernel::WarpMmaOperator1;
|
||||
using ArchMmaOperator = typename B2bImplicitGemmKernel::ArchMmaOperator;
|
||||
using MathOperator = typename B2bImplicitGemmKernel::MathOperator;
|
||||
|
||||
static cutlass::conv::Operator const kConvolutionalOperator = B2bImplicitGemmKernel::kConvolutionalOperator;
|
||||
static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = B2bImplicitGemmKernel::kIteratorAlgorithm;
|
||||
|
||||
static int const kWarpCount =
|
||||
(ThreadblockShape0::kM / WarpShape0::kM) *
|
||||
(ThreadblockShape0::kN / WarpShape0::kN);
|
||||
|
||||
/// Argument structure
|
||||
using Arguments = typename B2bImplicitGemmKernel::Arguments;
|
||||
|
||||
private:
|
||||
|
||||
/// Kernel parameters object
|
||||
typename B2bImplicitGemmKernel::Params params_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs Implicit GEMM
|
||||
B2bImplicitGemmConvolution() { }
|
||||
|
||||
/// Determines whether the Implicit GEMM can execute the given problem.
|
||||
static Status can_implement(Arguments const &args) {
|
||||
|
||||
// dispatch to iterators
|
||||
Status status = B2bImplicitGemmKernel::B2bMma::IteratorA0::can_implement(args.problem_size_0);
|
||||
if (Status::kSuccess != status) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = B2bImplicitGemmKernel::B2bMma::IteratorB0::can_implement(args.problem_size_0);
|
||||
if (Status::kSuccess != status) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = B2bImplicitGemmKernel::B2bMma::IteratorB1::can_implement(args.problem_size_1);
|
||||
if (Status::kSuccess != status) {
|
||||
return status;
|
||||
}
|
||||
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
dim3 grid = threadblock_swizzle.get_grid_shape(
|
||||
threadblock_swizzle.get_tiled_shape(
|
||||
cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0),
|
||||
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
|
||||
args.problem_size_0.split_k_slices));
|
||||
|
||||
if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
|
||||
grid.z <= std::numeric_limits<uint16_t>::max())) {
|
||||
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
// Determine if fusion sizes are valid
|
||||
|
||||
cutlass::gemm::GemmCoord problem_size_0 = implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0);
|
||||
cutlass::gemm::GemmCoord problem_size_1 = implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_1);
|
||||
|
||||
if(problem_size_0.m() != problem_size_1.m())
|
||||
return Status::kErrorInvalidProblem;
|
||||
|
||||
if(problem_size_0.n() != problem_size_1.k())
|
||||
return Status::kErrorInvalidProblem;
|
||||
|
||||
if(args.problem_size_1.R != 1 || args.problem_size_1.S != 1)
|
||||
return Status::kErrorInvalidProblem;
|
||||
|
||||
if(problem_size_0.n() > ThreadblockShape0::kN)
|
||||
return Status::kErrorInvalidProblem;
|
||||
|
||||
if(problem_size_1.n() > ThreadblockShape1::kN)
|
||||
return Status::kErrorInvalidProblem;
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t get_workspace_size(Arguments const &args) {
|
||||
|
||||
size_t workspace_bytes = 0;
|
||||
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||||
cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0),
|
||||
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
|
||||
args.problem_size_0.split_k_slices);
|
||||
|
||||
if(args.split_k_mode == SplitKMode::kParallel) {
|
||||
|
||||
// Split-K parallel: CTAs in the k-dimension write their partial results to a temporary workspace.
// The user needs to call a reduction operator to obtain the final output tensor.
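// Worked example of the formula below (illustrative numbers, not part of this change): with
// ElementAccumulator = float, an output tensor of 32x56x56x64 elements, and split_k_slices = 2,
// workspace_bytes = 4 * (32*56*56*64) * 2 = 51,380,224 bytes (roughly 49 MiB), whereas the
// serial split-K path further down only needs sizeof(int) per threadblock tile of the output.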
|
||||
workspace_bytes =
|
||||
sizeof(ElementAccumulator) *
|
||||
size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size_0)) *
|
||||
size_t(grid_tiled_shape.k());
|
||||
}
|
||||
|
||||
else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size_0.split_k_slices > 1) {
|
||||
|
||||
// Split-K serial: The user workspace is used to store semaphore and serialize writing the
|
||||
// final reduced output to user's output tensor
|
||||
workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
|
||||
}
|
||||
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status initialize(
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) {
|
||||
|
||||
if (args.problem_size_0.split_k_slices > 1) {
|
||||
|
||||
if (!workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream);
|
||||
|
||||
if (status != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
// initialize the params structure from the arguments
|
||||
params_ = typename B2bImplicitGemmKernel::Params(
|
||||
args,
|
||||
static_cast<int *>(workspace)
|
||||
);
|
||||
|
||||
int smem_size = int(sizeof(typename B2bImplicitGemmKernel::SharedStorage));
|
||||
|
||||
if (smem_size >= (48 << 10)) {
|
||||
cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel<B2bImplicitGemmKernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Updates the GEMM state (tensor pointers and output ops) from new arguments without reallocating workspace.
|
||||
Status update(Arguments const &args, void *workspace = nullptr) {
|
||||
|
||||
// update the params structure from the arguments
|
||||
params_.ptr_A0 = args.ref_A0.data();
|
||||
params_.ptr_B0 = args.ref_B0.data();
|
||||
params_.ptr_C0 = args.ref_C0.data();
|
||||
params_.ptr_Scale0 = args.ref_Scale0.data();
|
||||
params_.ptr_Bias0 = args.ref_Bias0.data();
|
||||
params_.ptr_B1 = args.ref_B1.data();
|
||||
params_.ptr_C1 = args.ref_C1.data();
|
||||
params_.ptr_D1 = args.ref_D1.data();
|
||||
params_.output_op_0 = args.output_op_0;
|
||||
params_.output_op_1 = args.output_op_1;
|
||||
params_.semaphore = static_cast<int *>(workspace);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status run(cudaStream_t stream = nullptr) {
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
|
||||
dim3 block(32 * kWarpCount, 1, 1);
|
||||
|
||||
int smem_size = int(sizeof(typename B2bImplicitGemmKernel::SharedStorage));
|
||||
|
||||
cutlass::Kernel<B2bImplicitGemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
cudaError_t result = cudaGetLastError();
|
||||
|
||||
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(cudaStream_t stream = nullptr) {
|
||||
return run(stream);
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) {
|
||||
|
||||
Status status = initialize(args, workspace, stream);
|
||||
|
||||
if (status == Status::kSuccess) {
|
||||
status = run(stream);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
} // namespace device
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
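// A minimal driver sketch for the device-level wrapper defined in this file. `ConcreteB2bKernel`
// and `args` are placeholders (a fully specialized B2bImplicitGemmKernel and a populated Arguments
// object); the call sequence mirrors the member functions above, CUTLASS_CHECK is the status macro
// used by the examples in this directory, and device_memory comes from cutlass/util/device_memory.h.
//
//   using B2bConv = cutlass::conv::device::B2bImplicitGemmConvolution<ConcreteB2bKernel>;
//
//   B2bConv b2b_conv_op;
//   CUTLASS_CHECK(B2bConv::can_implement(args));
//
//   size_t workspace_bytes = B2bConv::get_workspace_size(args);
//   cutlass::device_memory::allocation<uint8_t> workspace(workspace_bytes);
//
//   CUTLASS_CHECK(b2b_conv_op.initialize(args, workspace.get()));
//   CUTLASS_CHECK(b2b_conv_op());  // same as b2b_conv_op.run()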
233
examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_rf.cu
Normal file
@ -0,0 +1,233 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
||||
|
||||
#include "device/b2b_implicit_gemm_convolution.h"
|
||||
#include "b2b_conv2d_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_0 (
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{64, 3, 3, 64}, // filter size (KRSC)
|
||||
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{32, 56, 56, 64} // output size (NPQK)
|
||||
);
|
||||
cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_1 (
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{128, 1, 1, 64}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{32, 56, 56, 128} // output size (NPQK)
|
||||
);
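// How these two problem sizes map onto the back-to-back implicit GEMM (using the usual fprop
// mapping M = N*P*Q, N = K, K = R*S*C; the derived numbers are a worked example added here, not
// part of the original file):
//   conv0: GEMM0 is (M, N, K) = (32*56*56, 64, 3*3*64)  = (100352, 64, 576)
//   conv1: GEMM1 is (M, N, K) = (32*56*56, 128, 1*1*64) = (100352, 128, 64)
// The fusion requirements checked by B2bImplicitGemmConvolution::can_implement hold here: the two
// GEMMs share M, conv0's K (64) equals conv1's C (64), conv1 uses a 1x1 filter, and the threadblock
// tiles declared below use kN = 64 for the first GEMM and kN = 128 for the second.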
|
||||
|
||||
bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(1); //use beta for bias
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
|
||||
using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel0>;
|
||||
|
||||
using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel1>;
|
||||
|
||||
B2bNonFusedConv2dRun<Conv2dFprop0, Conv2dFprop1> nonFusedConv2d;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back FP16 Optimized Convolution Fprops...\n";
|
||||
bool pass = nonFusedConv2d.run(conv2d_f16_sm75_problem_size_0, conv2d_f16_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_conv2d_fprop_optimized_f16_sm75_rf_res() {
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(1); //use beta for bias
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
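// Added commentary (interpretation, not in the original diff): the element count above,
// InstructionShape::kM * InstructionShape::kN / 32 = 16 * 8 / 32 = 4, appears to match the
// per-thread accumulator fragment of one Tensor Core instruction tile, which is what lets the
// first convolution's output stay in the register file for the fused second convolution
// (the "RF residency" this example is named after).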
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>;
|
||||
|
||||
|
||||
const bool SmemAccumulator = false;
|
||||
|
||||
using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
SmemAccumulator
|
||||
>::Kernel;
|
||||
|
||||
using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;
|
||||
|
||||
B2bFusedConv2dRun<B2bConv2dFprop> fusedConv2d;
|
||||
|
||||
std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops with RF Residency...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_f16_sm75_problem_size_0, conv2d_f16_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_conv2d_fprop_optimized_f16_sm75,
|
||||
&run_fused_conv2d_fprop_optimized_f16_sm75_rf_res
|
||||
};
|
||||
|
||||
return testRun(75, funcs, "conv f16 RF residency");
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,233 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
||||
|
||||
#include "device/b2b_implicit_gemm_convolution.h"
|
||||
#include "b2b_conv2d_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_0 (
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{64, 3, 3, 64}, // filter size (KRSC)
|
||||
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{32, 56, 56, 64} // output size (NPQK)
|
||||
);
|
||||
cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_1 (
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{256, 1, 1, 64}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{32, 56, 56, 256} // output size (NPQK)
|
||||
);
|
||||
|
||||
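// Note on the two problem sizes above (an observation added for readability, not
// part of the original file): the back-to-back pair is a 3x3, 64->64 fprop on a
// 56x56 activation followed by a 1x1, 64->256 fprop with unit stride and no
// padding. Keeping the second filter 1x1 is what lets the fused kernel consume
// the first conv's per-threadblock output tile directly, with no halo between stages.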
bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
|
||||
using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel0>;
|
||||
|
||||
using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel1>;
|
||||
|
||||
B2bNonFusedConv2dRun<Conv2dFprop0, Conv2dFprop1> nonFusedConv2d;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back FP16 Optimized Convolution Fprops...\n";
|
||||
bool pass = nonFusedConv2d.run(conv2d_f16_sm75_problem_size_0, conv2d_f16_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_conv2d_fprop_optimized_f16_sm75_shmem() {
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
|
||||
const bool SmemAccumulator = true;
|
||||
|
||||
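// Added note: with SmemAccumulator == true the fused kernel stages the first
// conv's output tile in shared memory rather than registers, which is why this
// variant can use a second-stage warp shape (64x64) that does not span the full
// threadblock N (256), unlike the RF-resident variant.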
using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
SmemAccumulator
|
||||
>::Kernel;
|
||||
|
||||
using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;
|
||||
|
||||
B2bFusedConv2dRun<B2bConv2dFprop> fusedConv2d;
|
||||
|
||||
std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops with shared memory staging...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_f16_sm75_problem_size_0, conv2d_f16_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_conv2d_fprop_optimized_f16_sm75,
|
||||
&run_fused_conv2d_fprop_optimized_f16_sm75_shmem
|
||||
};
|
||||
|
||||
return testRun(75, funcs, "conv f16 shmem staging");
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
230
examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_rf.cu
Normal file
@@ -0,0 +1,230 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
||||
|
||||
#include "device/b2b_implicit_gemm_convolution.h"
|
||||
#include "b2b_conv2d_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_0 (
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{64, 3, 3, 64}, // filter size (KRSC)
|
||||
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{32, 56, 56, 64} // output size (NPQK)
|
||||
);
|
||||
cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_1 (
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{128, 1, 1, 64}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{32, 56, 56, 128} // output size (NPQK)
|
||||
);
|
||||
|
||||
|
||||
bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
|
||||
using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel0>;
|
||||
|
||||
using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel1>;
|
||||
|
||||
B2bNonFusedConv2dRun<Conv2dFprop0, Conv2dFprop1> nonFusedConv2d;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back FP16 Optimized Convolution Fprops...\n";
|
||||
bool pass = nonFusedConv2d.run(conv2d_f16_sm80_problem_size_0, conv2d_f16_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_conv2d_fprop_optimized_f16_sm80_rf_res() {
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
|
||||
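// Added note (assumed reading of this example): the first-stage epilogue's
// elements-per-access is the per-thread accumulator fragment of one 16x8 MMA
// (16 * 8 / 32 threads = 4), so ReLU and alpha scaling are applied on register
// fragments before the intermediate feeds the second convolution.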
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
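// Added note: this is the register-file-resident fusion on SM80. WarpShape1's N
// (128) covers the full ThreadblockShape1 N, so each warp holds the complete
// slice of the intermediate tile it needs for the second conv, and the kernel
// below uses 3 pipeline stages, as is typical for SM80 kernels in this example.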
using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;
|
||||
|
||||
B2bFusedConv2dRun<B2bConv2dFprop> fusedConv2d;
|
||||
|
||||
std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops with RF Residency...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_f16_sm80_problem_size_0, conv2d_f16_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
|
||||
}
|
||||
|
||||
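// Added note: main() registers the non-fused reference run and the fused
// RF-resident run; testRun's leading argument is the target SM (80 here), which
// the shared test_run.h helper is expected to use to skip unsupported devices.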
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_conv2d_fprop_optimized_f16_sm80,
|
||||
&run_fused_conv2d_fprop_optimized_f16_sm80_rf_res
|
||||
};
|
||||
|
||||
return testRun(80, funcs, "conv f16 RF residency");
|
||||
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -0,0 +1,233 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
||||
|
||||
#include "device/b2b_implicit_gemm_convolution.h"
|
||||
#include "b2b_conv2d_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_0 (
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{64, 3, 3, 64}, // filter size (KRSC)
|
||||
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{32, 56, 56, 64} // output size (NPQK)
|
||||
);
|
||||
cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_1 (
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{256, 1, 1, 64}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{32, 56, 56, 256} // output size (NPQK)
|
||||
);
|
||||
|
||||
|
||||
bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
|
||||
using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel0>;
|
||||
|
||||
using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel1>;
|
||||
|
||||
B2bNonFusedConv2dRun<Conv2dFprop0, Conv2dFprop1> nonFusedConv2d;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back FP16 Optimized Convolution Fprops...\n";
|
||||
bool pass = nonFusedConv2d.run(conv2d_f16_sm80_problem_size_0, conv2d_f16_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_conv2d_fprop_optimized_f16_sm80_shmem() {
|
||||
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
const bool SmemAccumulator = true;
|
||||
using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
SmemAccumulator
|
||||
>::Kernel;
|
||||
|
||||
using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;
|
||||
|
||||
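// Added note: B2bFusedConv2dRun (from b2b_conv2d_run.h) allocates and initializes
// the tensors, launches the fused kernel, and compares the result against a
// reference computation; run() returns that pass/fail verdict.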
B2bFusedConv2dRun<B2bConv2dFprop> fusedConv2d;
|
||||
|
||||
std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops with shared memory staging...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_f16_sm80_problem_size_0, conv2d_f16_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_conv2d_fprop_optimized_f16_sm80,
|
||||
&run_fused_conv2d_fprop_optimized_f16_sm80_shmem
|
||||
};
|
||||
|
||||
return testRun(80, funcs, "conv f16 shmem staging");
|
||||
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
235
examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_rf.cu
Normal file
@@ -0,0 +1,235 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
||||
|
||||
#include "device/b2b_implicit_gemm_convolution.h"
|
||||
#include "b2b_interleaved_conv2d_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_0 (
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{64, 3, 3, 64}, // filter size (KRSC)
|
||||
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{32, 56, 56, 64} // output size (NPQK)
|
||||
);
|
||||
cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_1 (
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{128, 1, 1, 64}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{32, 56, 56, 128} // output size (NPQK)
|
||||
);
|
||||
|
||||
bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
|
||||
|
||||
using ElementA = int8_t;
|
||||
using ElementB = int8_t;
|
||||
using ElementC = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(1);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
|
||||
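// Added note: the INT8 path uses the 32-element interleaved TensorNCxHWx<32>
// layout for activations and TensorCxRSKx<32> for filters, matching the int8
// Tensor Core MMA; the trailing 32 on the B2bInterleaved*Conv2dRun helpers is
// that same interleave factor.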
using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementB, cutlass::layout::TensorCxRSKx<32>,
|
||||
ElementC, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
64 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel0>;
|
||||
|
||||
using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementB, cutlass::layout::TensorCxRSKx<32>,
|
||||
ElementC, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
64 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel1>;
|
||||
|
||||
B2bInterleavedNonFusedConv2dRun<Conv2dFprop0, Conv2dFprop1, 32> nonFusedConv2d;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back INT8 interleaved Optimized Convolution Fprops...\n";
|
||||
bool pass = nonFusedConv2d.run(conv2d_s8_sm75_problem_size_0, conv2d_s8_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
bool run_fused_conv2d_fprop_optimized_s8_sm75_rf_res() {
|
||||
|
||||
using ElementA = int8_t;
|
||||
using ElementB = int8_t;
|
||||
using ElementC = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(1);
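  // Added note: as in the FP16 RF variant, beta1 == 1 keeps the second epilogue's
  // source term, which this example uses to fold a bias/residual tensor into the output.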
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
64 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
|
||||
const bool SmemAccumulator = false;
|
||||
|
||||
using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementB, cutlass::layout::TensorCxRSKx<32>,
|
||||
ElementC, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
SmemAccumulator
|
||||
>::Kernel;
|
||||
|
||||
using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;
|
||||
|
||||
B2bInterleavedFusedConv2dRun<B2bConv2dFprop, 32> fusedConv2d;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops with RF residency...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_s8_sm75_problem_size_0, conv2d_s8_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_conv2d_fprop_optimized_s8_sm75,
|
||||
&run_fused_conv2d_fprop_optimized_s8_sm75_rf_res
|
||||
};
|
||||
|
||||
return testRun(75, funcs, "conv int8 RF residency");
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -0,0 +1,235 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
||||
|
||||
#include "device/b2b_implicit_gemm_convolution.h"
|
||||
#include "b2b_interleaved_conv2d_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_0 (
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{64, 3, 3, 64}, // filter size (KRSC)
|
||||
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{32, 56, 56, 64} // output size (NPQK)
|
||||
);
|
||||
cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_1 (
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{256, 1, 1, 64}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{32, 56, 56, 256} // output size (NPQK)
|
||||
);
|
||||
|
||||
bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
|
||||
|
||||
using ElementA = int8_t;
|
||||
using ElementB = int8_t;
|
||||
using ElementC = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(1);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
|
||||
using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementB, cutlass::layout::TensorCxRSKx<32>,
|
||||
ElementC, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
64 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel0>;
|
||||
|
||||
using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementB, cutlass::layout::TensorCxRSKx<32>,
|
||||
ElementC, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
64 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel1>;
|
||||
|
||||
B2bInterleavedNonFusedConv2dRun<Conv2dFprop0, Conv2dFprop1, 32> nonFusedConv2d;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back INT8 interleaved Optimized Convolution Fprops...\n";
|
||||
bool pass = nonFusedConv2d.run(conv2d_s8_sm75_problem_size_0, conv2d_s8_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_conv2d_fprop_optimized_s8_sm75_shmem() {
|
||||
|
||||
using ElementA = int8_t;
|
||||
using ElementB = int8_t;
|
||||
using ElementC = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(1);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
64 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
|
||||
const bool SmemAccumulator = true;
|
||||
|
||||
using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementB, cutlass::layout::TensorCxRSKx<32>,
|
||||
ElementC, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
SmemAccumulator
|
||||
>::Kernel;
|
||||
|
||||
using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;
|
||||
|
||||
B2bInterleavedFusedConv2dRun<B2bConv2dFprop, 32> fusedConv2d;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops with shared memory staging...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_s8_sm75_problem_size_0, conv2d_s8_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_conv2d_fprop_optimized_s8_sm75,
|
||||
&run_fused_conv2d_fprop_optimized_s8_sm75_shmem
|
||||
};
|
||||
|
||||
return testRun(75, funcs, "conv int8 shmem staging");
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
233
examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_rf.cu
Normal file
@@ -0,0 +1,233 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
||||
|
||||
#include "device/b2b_implicit_gemm_convolution.h"
|
||||
#include "b2b_interleaved_conv2d_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_0 (
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{64, 3, 3, 64}, // filter size (KRSC)
|
||||
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{32, 56, 56, 64} // output size (NPQK)
|
||||
);
|
||||
cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_1 (
|
||||
{32, 56, 56, 64}, // input size (NHWC)
|
||||
{128, 1, 1, 64}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1}, // dilation (dilation_h, dilation_w)
|
||||
{32, 56, 56, 128} // output size (NPQK)
|
||||
);
|
||||
|
||||
bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
|
||||
|
||||
using ElementA = int8_t;
|
||||
using ElementB = int8_t;
|
||||
using ElementC = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementB, cutlass::layout::TensorCxRSKx<32>,
|
||||
ElementC, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
64 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel0>;
|
||||
|
||||
using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementB, cutlass::layout::TensorCxRSKx<32>,
|
||||
ElementC, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
64 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel1>;
|
||||
|
||||
B2bInterleavedNonFusedConv2dRun<Conv2dFprop0, Conv2dFprop1, 32> nonFusedConv2d;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back INT8 interleaved Optimized Convolution Fprops...\n";
|
||||
bool pass = nonFusedConv2d.run(conv2d_s8_sm80_problem_size_0, conv2d_s8_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
bool run_fused_conv2d_fprop_optimized_s8_sm80_rf_res() {
|
||||
|
||||
using ElementA = int8_t;
|
||||
using ElementB = int8_t;
|
||||
using ElementC = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
8 * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementC,
|
||||
64 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
|
||||
|
||||
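// Added note: like the FP16 SM80 RF variant, this kernel omits the trailing
// SmemAccumulator argument and relies on the template's default; given the file
// name and the "RF Residency" banner below, that default is taken here to be the
// register-file-resident path.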
using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
|
||||
ElementA, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementB, cutlass::layout::TensorCxRSKx<32>,
|
||||
ElementC, cutlass::layout::TensorNCxHWx<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized
|
||||
>::Kernel;
|
||||
|
||||
using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;
|
||||
|
||||
B2bInterleavedFusedConv2dRun<B2bConv2dFprop, 32> fusedConv2d;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops with RF residency...\n";
|
||||
bool pass = fusedConv2d.run(conv2d_s8_sm80_problem_size_0, conv2d_s8_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial,
|
||||
alpha0, beta0, alpha1, beta1);
|
||||
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_conv2d_fprop_optimized_s8_sm80,
|
||||
&run_fused_conv2d_fprop_optimized_s8_sm80_rf_res
|
||||
};
|
||||
|
||||
return testRun(80, funcs, "conv int8 RF residency");
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////

@@ -0,0 +1,234 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (Standard BSD-3-Clause license text omitted.)
 **************************************************************************************************/
#include <iostream>

#include "cutlass/cutlass.h"

#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"

#include "device/b2b_implicit_gemm_convolution.h"
#include "b2b_interleaved_conv2d_run.h"
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_0 (
  {32, 56, 56, 64},   // input size  (NHWC)
  {64, 3, 3, 64},     // filter size (KRSC)
  {1, 1, 1, 1},       // padding (pad_h, _, pad_w, _)
  {1, 1},             // stride (stride_h, stride_w)
  {1, 1},             // dilation (dilation_h, dilation_w)
  {32, 56, 56, 64}    // output size (NPQK)
);
cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_1 (
  {32, 56, 56, 64},   // input size  (NHWC)
  {256, 1, 1, 64},    // filter size (KRSC)
  {0, 0, 0, 0},       // padding (pad_h, _, pad_w, _)
  {1, 1},             // stride (stride_h, stride_w)
  {1, 1},             // dilation (dilation_h, dilation_w)
  {32, 56, 56, 256}   // output size (NPQK)
);
bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {

  using ElementA = int8_t;
  using ElementB = int8_t;
  using ElementC = int8_t;
  using ElementAccumulator = int32_t;
  using ElementCompute = float;

  ElementCompute alpha0 = ElementCompute(1);
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(0);

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>;
  using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;

  using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
    ElementA, cutlass::layout::TensorNCxHWx<32>,
    ElementB, cutlass::layout::TensorCxRSKx<32>,
    ElementC, cutlass::layout::TensorNCxHWx<32>,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape0, WarpShape0, InstructionShape,
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementC, 64 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator, ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3,
    cutlass::arch::OpMultiplyAddSaturate,
    cutlass::conv::IteratorAlgorithm::kOptimized
  >::Kernel;

  using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel0>;

  using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop<
    ElementA, cutlass::layout::TensorNCxHWx<32>,
    ElementB, cutlass::layout::TensorCxRSKx<32>,
    ElementC, cutlass::layout::TensorNCxHWx<32>,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape1, WarpShape1, InstructionShape,
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementC, 64 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator, ElementCompute
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3,
    cutlass::arch::OpMultiplyAddSaturate,
    cutlass::conv::IteratorAlgorithm::kOptimized
  >::Kernel;

  using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel1>;

  B2bInterleavedNonFusedConv2dRun<Conv2dFprop0, Conv2dFprop1, 32> nonFusedConv2d;

  std::cout << "Running Non-fused back-to-back INT8 interleaved Optimized Convolution Fprops...\n";
  bool pass = nonFusedConv2d.run(conv2d_s8_sm80_problem_size_0, conv2d_s8_sm80_problem_size_1,
                                 cutlass::conv::SplitKMode::kSerial,
                                 alpha0, beta0, alpha1, beta1);

  if(pass)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return pass;
}
bool run_fused_conv2d_fprop_optimized_s8_sm80_shmem() {

  using ElementA = int8_t;
  using ElementB = int8_t;
  using ElementC = int8_t;
  using ElementAccumulator = int32_t;
  using ElementCompute = float;

  ElementCompute alpha0 = ElementCompute(1);
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(0);

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>;
  using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;

  using EpilogueOutputOp0 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementC, 8 * InstructionShape::kN / 32,
      ElementAccumulator, ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    >;

  using EpilogueOutputOp1 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementC, 64 / cutlass::sizeof_bits<ElementC>::value,
      ElementAccumulator, ElementCompute
    >;

  const bool SmemAccumulator = true;

  using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
    ElementA, cutlass::layout::TensorNCxHWx<32>,
    ElementB, cutlass::layout::TensorCxRSKx<32>,
    ElementC, cutlass::layout::TensorNCxHWx<32>,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape0, ThreadblockShape1,
    WarpShape0, WarpShape1,
    InstructionShape,
    EpilogueOutputOp0, EpilogueOutputOp1,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3,
    cutlass::arch::OpMultiplyAddSaturate,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    SmemAccumulator
  >::Kernel;

  using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution<B2bConv2dFpropKernel>;

  B2bInterleavedFusedConv2dRun<B2bConv2dFprop, 32> fusedConv2d;

  std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops with shared memory staging...\n";
  bool pass = fusedConv2d.run(conv2d_s8_sm80_problem_size_0, conv2d_s8_sm80_problem_size_1,
                              cutlass::conv::SplitKMode::kSerial,
                              alpha0, beta0, alpha1, beta1);

  if(pass)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return pass;
}

int main() {

  std::vector<bool (*)()> funcs = {
    &run_nonfused_conv2d_fprop_optimized_s8_sm80,
    &run_fused_conv2d_fprop_optimized_s8_sm80_shmem
  };

  return testRun(80, funcs, "conv int8 shmem staging");
}

////////////////////////////////////////////////////////////////////////////////
@@ -1,29 +1,33 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (License header updated from the 2017-2020 NVIDIA CORPORATION notice to the standard
 *  BSD-3-Clause wording; full license text omitted.)
 **************************************************************************************************/
#pragma once

#include <iostream>

#include "cutlass/cutlass.h"
@@ -38,28 +42,28 @@

#include "device/b2b_gemm.h"
#include "b2b_gemm_run.h"

#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

void run_nonfused_gemm_f16() {
cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_0(128*640, 64, 576);
cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_1(128*640, 128, 64);

bool run_nonfused_gemm_f16() {

  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementCompute = cutlass::half_t;

  cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576);
  cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64);
  ElementCompute alpha0 = ElementCompute(2);
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(2);
  ElementCompute beta1 = ElementCompute(1);

  using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

  using Gemm0 = cutlass::gemm::device::Gemm<
@@ -79,7 +83,8 @@ void run_nonfused_gemm_f16() {
      ElementOutput,
      128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator,
      ElementCompute
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    2
@@ -110,29 +115,30 @@ void run_nonfused_gemm_f16() {
  B2bNonFusedGemmRun<Gemm0, Gemm1> nonFusedGemm;

  std::cout << "Running Non-fused back-to-back FP16 TN GEMMs...\n";
  bool pass = nonFusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1);
  bool pass = nonFusedGemm.run(gemm_f16_sm75_problem_size_0, gemm_f16_sm75_problem_size_1, alpha0, beta0, alpha1, beta1);
  if(pass)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return pass;
}

void run_fused_gemm_f16() {

bool run_fused_gemm_f16_rf_res() {

  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementCompute = cutlass::half_t;

  cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576);
  cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64);
  ElementCompute alpha0 = ElementCompute(2);
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(2);
  ElementCompute beta1 = ElementCompute(1);

  using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

@@ -141,7 +147,8 @@ void run_fused_gemm_f16() {
      ElementOutput,
      InstructionShape::kM * InstructionShape::kN / 32,
      ElementAccumulator,
      ElementCompute
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    >;

  using EpilogueOutputOp1 =
@@ -152,8 +159,6 @@ void run_fused_gemm_f16() {
      ElementCompute
    >;


  using B2bGemm = cutlass::gemm::device::B2bGemm<
    cutlass::half_t,
    cutlass::layout::RowMajor,
@@ -177,14 +182,26 @@ void run_fused_gemm_f16() {

  B2bFusedGemmRun<B2bGemm> fusedGemm;

  std::cout << "Running Fused back-to-back FP16 TN GEMMs...\n";
  bool passed = fusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1);
  std::cout << "Running Fused back-to-back FP16 TN GEMMs with RF Residency...\n";
  bool passed = fusedGemm.run(gemm_f16_sm75_problem_size_0, gemm_f16_sm75_problem_size_1, alpha0, beta0, alpha1, beta1);
  if(passed)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return passed;
}
////////////////////////////////////////////////////////////////////////////////

#endif //#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
int main() {

  std::vector<bool (*)()> funcs = {
    &run_nonfused_gemm_f16,
    &run_fused_gemm_f16_rf_res
  };

  return testRun(75, funcs, "gemm f16 RF residency");
}

///////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,211 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (Standard BSD-3-Clause license text omitted.)
 **************************************************************************************************/
#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"

#include "device/b2b_gemm.h"
#include "b2b_gemm_run.h"
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_0(128*640, 64, 576);
cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_1(128*640, 256, 64);
bool run_nonfused_gemm_f16() {

  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementCompute = cutlass::half_t;

  ElementCompute alpha0 = ElementCompute(2);
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(2);
  ElementCompute beta1 = ElementCompute(1);

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

  using Gemm0 = cutlass::gemm::device::Gemm<
    cutlass::half_t, cutlass::layout::RowMajor,
    cutlass::half_t, cutlass::layout::ColumnMajor,
    ElementOutput, cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    ThreadblockShape0, WarpShape0, InstructionShape,
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    2
  >;
  using Gemm1 = cutlass::gemm::device::Gemm<
    cutlass::half_t, cutlass::layout::RowMajor,
    cutlass::half_t, cutlass::layout::ColumnMajor,
    ElementOutput, cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    ThreadblockShape1, WarpShape1, InstructionShape,
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementCompute
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    2
  >;

  B2bNonFusedGemmRun<Gemm0, Gemm1> nonFusedGemm;

  std::cout << "Running Non-fused back-to-back FP16 TN GEMMs...\n";
  bool pass = nonFusedGemm.run(gemm_f16_sm75_problem_size_0, gemm_f16_sm75_problem_size_1, alpha0, beta0, alpha1, beta1);
  if(pass)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return pass;
}
bool run_fused_gemm_f16_shmem() {

  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementCompute = cutlass::half_t;

  ElementCompute alpha0 = ElementCompute(2);
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(2);
  ElementCompute beta1 = ElementCompute(1);

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

  using EpilogueOutputOp0 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput, InstructionShape::kM * InstructionShape::kN / 32,
      ElementAccumulator, ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    >;

  using EpilogueOutputOp1 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementCompute
    >;

  const bool SmemAccumulator = true;

  using B2bGemm = cutlass::gemm::device::B2bGemm<
    cutlass::half_t, cutlass::layout::RowMajor,
    cutlass::half_t, cutlass::layout::ColumnMajor,
    ElementOutput, cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    ThreadblockShape0, ThreadblockShape1,
    WarpShape0, WarpShape1,
    InstructionShape,
    EpilogueOutputOp0, EpilogueOutputOp1,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    2,
    SmemAccumulator
  >;

  B2bFusedGemmRun<B2bGemm> fusedGemm;

  std::cout << "Running Fused back-to-back FP16 TN GEMMs with shared memory staging...\n";
  bool passed = fusedGemm.run(gemm_f16_sm75_problem_size_0, gemm_f16_sm75_problem_size_1, alpha0, beta0, alpha1, beta1);
  if(passed)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return passed;
}

int main() {

  std::vector<bool (*)()> funcs = {
    &run_nonfused_gemm_f16,
    &run_fused_gemm_f16_shmem
  };

  return testRun(75, funcs, "gemm f16 shmem staging");
}

////////////////////////////////////////////////////////////////////////////////
212 examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_rf.cu Normal file

@@ -0,0 +1,212 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (Standard BSD-3-Clause license text omitted.)
 **************************************************************************************************/
#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"

#include "device/b2b_gemm.h"
#include "b2b_gemm_run.h"
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_0(128*640, 64, 576);
cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_1(128*640, 128, 64);
bool run_nonfused_gemm_f16_sm80() {

  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementCompute = cutlass::half_t;

  ElementCompute alpha0 = ElementCompute(2);
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(2);
  ElementCompute beta1 = ElementCompute(1);

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  using Gemm0 = cutlass::gemm::device::Gemm<
    cutlass::half_t, cutlass::layout::RowMajor,
    cutlass::half_t, cutlass::layout::ColumnMajor,
    ElementOutput, cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape0, WarpShape0, InstructionShape,
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3
  >;
  using Gemm1 = cutlass::gemm::device::Gemm<
    cutlass::half_t, cutlass::layout::RowMajor,
    cutlass::half_t, cutlass::layout::ColumnMajor,
    ElementOutput, cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape1, WarpShape1, InstructionShape,
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementCompute
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3
  >;

  B2bNonFusedGemmRun<Gemm0, Gemm1> nonFusedGemm;

  std::cout << "Running Non-fused back-to-back FP16 TN GEMMs...\n";
  bool pass = nonFusedGemm.run(gemm_f16_sm80_problem_size_0, gemm_f16_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
  if(pass)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return pass;
}
bool run_fused_gemm_f16_sm80_rf_res() {

  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementCompute = cutlass::half_t;

  ElementCompute alpha0 = ElementCompute(2);
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(2);
  ElementCompute beta1 = ElementCompute(1);

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  using EpilogueOutputOp0 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput, InstructionShape::kM * InstructionShape::kN / 32,
      ElementAccumulator, ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    >;

  using EpilogueOutputOp1 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementCompute
    >;

  using B2bGemm = cutlass::gemm::device::B2bGemm<
    cutlass::half_t, cutlass::layout::RowMajor,
    cutlass::half_t, cutlass::layout::ColumnMajor,
    ElementOutput, cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape0, ThreadblockShape1,
    WarpShape0, WarpShape1,
    InstructionShape,
    EpilogueOutputOp0, EpilogueOutputOp1,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3
  >;

  B2bFusedGemmRun<B2bGemm> fusedGemm;

  std::cout << "Running Fused back-to-back FP16 TN GEMMs with RF residency...\n";
  bool passed = fusedGemm.run(gemm_f16_sm80_problem_size_0, gemm_f16_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
  if(passed)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return passed;
}

int main() {

  std::vector<bool (*)()> funcs = {
    &run_nonfused_gemm_f16_sm80,
    &run_fused_gemm_f16_sm80_rf_res
  };

  return testRun(80, funcs, "gemm f16 RF residency");
}

////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,214 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (Standard BSD-3-Clause license text omitted.)
 **************************************************************************************************/
#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"

#include "device/b2b_gemm.h"
#include "b2b_gemm_run.h"
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_0(128*640, 64, 576);
cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_1(128*640, 256, 64);
bool run_nonfused_gemm_f16_sm80() {

  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementCompute = cutlass::half_t;

  ElementCompute alpha0 = ElementCompute(2);
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(2);
  ElementCompute beta1 = ElementCompute(1);

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  using Gemm0 = cutlass::gemm::device::Gemm<
    cutlass::half_t, cutlass::layout::RowMajor,
    cutlass::half_t, cutlass::layout::ColumnMajor,
    ElementOutput, cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape0, WarpShape0, InstructionShape,
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3
  >;
  using Gemm1 = cutlass::gemm::device::Gemm<
    cutlass::half_t, cutlass::layout::RowMajor,
    cutlass::half_t, cutlass::layout::ColumnMajor,
    ElementOutput, cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape1, WarpShape1, InstructionShape,
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementCompute
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3
  >;

  B2bNonFusedGemmRun<Gemm0, Gemm1> nonFusedGemm;

  std::cout << "Running Non-fused back-to-back FP16 TN GEMMs...\n";
  bool pass = nonFusedGemm.run(gemm_f16_sm80_problem_size_0, gemm_f16_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
  if(pass)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return pass;
}
bool run_fused_gemm_f16_sm80_shmem() {

  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementCompute = cutlass::half_t;

  ElementCompute alpha0 = ElementCompute(2);
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(2);
  ElementCompute beta1 = ElementCompute(1);

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  using EpilogueOutputOp0 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput, InstructionShape::kM * InstructionShape::kN / 32,
      ElementAccumulator, ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    >;

  using EpilogueOutputOp1 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementCompute
    >;

  const bool SmemAccumulator = true;

  using B2bGemm = cutlass::gemm::device::B2bGemm<
    cutlass::half_t, cutlass::layout::RowMajor,
    cutlass::half_t, cutlass::layout::ColumnMajor,
    ElementOutput, cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape0, ThreadblockShape1,
    WarpShape0, WarpShape1,
    InstructionShape,
    EpilogueOutputOp0, EpilogueOutputOp1,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3,
    SmemAccumulator
  >;

  B2bFusedGemmRun<B2bGemm> fusedGemm;

  std::cout << "Running Fused back-to-back FP16 TN GEMMs with shared memory staging...\n";
  bool passed = fusedGemm.run(gemm_f16_sm80_problem_size_0, gemm_f16_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
  if(passed)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return passed;
}

int main() {

  std::vector<bool (*)()> funcs = {
    &run_nonfused_gemm_f16_sm80,
    &run_fused_gemm_f16_sm80_shmem
  };

  return testRun(80, funcs, "gemm f16 shmem staging");
}

////////////////////////////////////////////////////////////////////////////////
@@ -1,29 +1,33 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (License header updated from the 2017-2020 NVIDIA CORPORATION notice to the standard
 *  BSD-3-Clause wording; full license text omitted.)
 **************************************************************************************************/
#pragma once

#include <iostream>

#include "cutlass/cutlass.h"
@@ -38,19 +42,19 @@

#include "device/b2b_gemm.h"
#include "b2b_interleaved_gemm_run.h"

#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

void run_nonfused_gemm_s8() {
cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_0(128*640, 64, 576);
cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_1(128*640, 128, 64);

bool run_nonfused_gemm_s8() {

  using ElementOutput = int8_t;
  using ElementAccumulator = int32_t;
  using ElementCompute = float;

  cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576);
  cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64);
  ElementCompute alpha0 = ElementCompute(2);
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(2);
@@ -58,8 +62,8 @@ void run_nonfused_gemm_s8() {

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpShape1 = cutlass::gemm::GemmShape<32, 32, 64>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
  using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;

  using Gemm0 = cutlass::gemm::device::Gemm<
@@ -79,7 +83,8 @@ void run_nonfused_gemm_s8() {
      ElementOutput,
      64 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator,
      ElementCompute
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    2
@@ -110,30 +115,31 @@ void run_nonfused_gemm_s8() {
  B2bInterleavedNonFusedGemmRun<Gemm0, Gemm1, 32> nonFusedGemm;

  std::cout << "Running Non-fused back-to-back INT8 NT interleaved GEMMs...\n";
  bool pass = nonFusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1);
  bool pass = nonFusedGemm.run(gemm_s8_sm75_problem_size_0, gemm_s8_sm75_problem_size_1, alpha0, beta0, alpha1, beta1);
  if(pass)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return pass;
}

void run_fused_gemm_s8() {

bool run_fused_gemm_s8_rf_res() {

  using ElementOutput = int8_t;
  using ElementAccumulator = int32_t;
  using ElementCompute = float;

  cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576);
  cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64);
  ElementCompute alpha0 = ElementCompute(2);
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(2);
  ElementCompute beta1 = ElementCompute(1);

  using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 64>;
  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;

  using EpilogueOutputOp0 =
@@ -141,7 +147,8 @@ void run_fused_gemm_s8() {
      ElementOutput,
      InstructionShape::kM * InstructionShape::kN / 32,
      ElementAccumulator,
      ElementCompute
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    >;

  using EpilogueOutputOp1 =
@@ -152,8 +159,6 @@ void run_fused_gemm_s8() {
      ElementCompute
    >;


  using B2bGemm = cutlass::gemm::device::B2bGemm<
    int8_t,
    cutlass::layout::ColumnMajorInterleaved<32>,
@@ -177,14 +182,28 @@ void run_fused_gemm_s8() {

  B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;

  std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs...\n";
  bool passed = fusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1);
  std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with RF Residency...\n";
  bool passed = fusedGemm.run(gemm_s8_sm75_problem_size_0, gemm_s8_sm75_problem_size_1, alpha0, beta0, alpha1, beta1);
  if(passed)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

}
////////////////////////////////////////////////////////////////////////////////
  return passed;

#endif // #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
}

int main() {

  std::vector<bool (*)()> funcs = {
    &run_nonfused_gemm_s8,
    &run_fused_gemm_s8_rf_res
  };

  return testRun(75, funcs, "gemm f16 RF residency");
}

////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,211 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (Standard BSD-3-Clause license text omitted.)
 **************************************************************************************************/
#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"

#include "device/b2b_gemm.h"
#include "b2b_interleaved_gemm_run.h"
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_0(128*640, 64, 576);
cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_1(128*640, 256, 64);
bool run_nonfused_gemm_s8() {

  using ElementOutput = int8_t;
  using ElementAccumulator = int32_t;
  using ElementCompute = float;

  ElementCompute alpha0 = ElementCompute(2);
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(2);
  ElementCompute beta1 = ElementCompute(1);

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>;
  using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;

  using Gemm0 = cutlass::gemm::device::Gemm<
    int8_t, cutlass::layout::ColumnMajorInterleaved<32>,
    int8_t, cutlass::layout::RowMajorInterleaved<32>,
    ElementOutput, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    ThreadblockShape0, WarpShape0, InstructionShape,
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput, 64 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    2
  >;
  using Gemm1 = cutlass::gemm::device::Gemm<
    int8_t, cutlass::layout::ColumnMajorInterleaved<32>,
    int8_t, cutlass::layout::RowMajorInterleaved<32>,
    ElementOutput, cutlass::layout::ColumnMajorInterleaved<32>,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    ThreadblockShape1, WarpShape1, InstructionShape,
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput, 64 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementCompute
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    2
  >;

  B2bInterleavedNonFusedGemmRun<Gemm0, Gemm1, 32> nonFusedGemm;

  std::cout << "Running Non-fused back-to-back INT8 NT interleaved GEMMs...\n";
  bool pass = nonFusedGemm.run(gemm_s8_sm75_problem_size_0, gemm_s8_sm75_problem_size_1, alpha0, beta0, alpha1, beta1);
  if(pass)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return pass;
}
bool run_fused_gemm_s8_shmem() {
|
||||
|
||||
using ElementOutput = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(2);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(2);
|
||||
ElementCompute beta1 = ElementCompute(1);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
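// The per-access count above (InstructionShape::kM * InstructionShape::kN / 32, i.e. 2 for the
// 8x8x16 tensor op) appears to match the accumulator fragment each thread of a 32-thread warp
// holds for one mma instruction, so the first GEMM's epilogue is applied fragment by fragment
// before the intermediate result is staged in shared memory.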
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
const bool SmemAccumulator = true;
|
||||
|
||||
using B2bGemm = cutlass::gemm::device::B2bGemm<
|
||||
int8_t,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
int8_t,
|
||||
cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementOutput,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
SmemAccumulator
|
||||
>;
|
||||
|
||||
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with shared memory staging...\n";
|
||||
bool passed = fusedGemm.run(gemm_s8_sm75_problem_size_0, gemm_s8_sm75_problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return passed;
|
||||
|
||||
}
|
||||
|
||||
int main() {

  std::vector<bool (*)()> funcs = {
    &run_nonfused_gemm_s8,
    &run_fused_gemm_s8_shmem
  };

  return testRun(75, funcs, "gemm s8 shmem staging");


}


////////////////////////////////////////////////////////////////////////////////
|
||||
226
examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_rf.cu
Normal file
@ -0,0 +1,226 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
|
||||
#include "device/b2b_gemm.h"
|
||||
#include "b2b_interleaved_gemm_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(128*640, 64, 576);
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(128*640, 128, 64);
|
||||
|
||||
bool run_nonfused_gemm_s8_sm80() {
|
||||
|
||||
using ElementOutput = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(2);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(2);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using Gemm0 = cutlass::gemm::device::Gemm<
|
||||
int8_t,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
int8_t,
|
||||
cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementOutput,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
3,
|
||||
16,
|
||||
16,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAddSaturate
|
||||
>;
|
||||
using Gemm1 = cutlass::gemm::device::Gemm<
|
||||
int8_t,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
int8_t,
|
||||
cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementOutput,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
3,
|
||||
16,
|
||||
16,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAddSaturate
|
||||
>;
|
||||
|
||||
B2bInterleavedNonFusedGemmRun<Gemm0, Gemm1, 32> nonFusedGemm;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back INT8 NT interleaved GEMMs...\n";
|
||||
bool pass = nonFusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
bool run_fused_gemm_s8_sm80_rf_res() {
|
||||
|
||||
using ElementOutput = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(2);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(2);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
8 * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
const bool SmemAccumulator = false;
|
||||
|
||||
using B2bGemm = cutlass::gemm::device::B2bGemm<
|
||||
int8_t,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
int8_t,
|
||||
cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementOutput,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
3,
|
||||
SmemAccumulator,
|
||||
16,
|
||||
16,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAddSaturate
|
||||
>;
|
||||
|
||||
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with RF residency...\n";
|
||||
bool passed = fusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
|
||||
int main() {

  std::vector<bool (*)()> funcs = {
    &run_nonfused_gemm_s8_sm80,
    &run_fused_gemm_s8_sm80_rf_res
  };

  return testRun(80, funcs, "gemm int8 RF residency");


}


////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,225 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
|
||||
#include "device/b2b_gemm.h"
|
||||
#include "b2b_interleaved_gemm_run.h"
|
||||
#include "test_run.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(128*640, 64, 576);
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(128*640, 256, 64);
|
||||
|
||||
bool run_nonfused_gemm_s8_sm80() {
|
||||
|
||||
using ElementOutput = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(2);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(2);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using Gemm0 = cutlass::gemm::device::Gemm<
|
||||
int8_t,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
int8_t,
|
||||
cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementOutput,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
3,
|
||||
16,
|
||||
16,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAddSaturate
|
||||
>;
|
||||
using Gemm1 = cutlass::gemm::device::Gemm<
|
||||
int8_t,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
int8_t,
|
||||
cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementOutput,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
3,
|
||||
16,
|
||||
16,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAddSaturate
|
||||
>;
|
||||
|
||||
B2bInterleavedNonFusedGemmRun<Gemm0, Gemm1, 32> nonFusedGemm;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back INT8 NT interleaved GEMMs...\n";
|
||||
bool pass = nonFusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_fused_gemm_s8_sm80_shmem() {
|
||||
|
||||
using ElementOutput = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(2);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(2);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
8 * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
const bool SmemAccumulator = true;
|
||||
|
||||
using B2bGemm = cutlass::gemm::device::B2bGemm<
|
||||
int8_t,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
int8_t,
|
||||
cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementOutput,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
3,
|
||||
SmemAccumulator,
|
||||
16,
|
||||
16,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAddSaturate
|
||||
>;
|
||||
|
||||
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with shared memory staging...\n";
|
||||
bool passed = fusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
|
||||
int main() {

  std::vector<bool (*)()> funcs = {
    &run_nonfused_gemm_s8_sm80,
    &run_fused_gemm_s8_sm80_shmem
  };

  return testRun(80, funcs, "gemm int8 shmem staging");


}


////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -66,6 +72,7 @@ struct B2bGemm {
    cutlass::gemm::GemmCoord problem_size_0;
    cutlass::gemm::GemmCoord problem_size_1;
    cutlass::gemm::GemmCoord grid_tiled_shape;
    int swizzle_log_tile;
    typename B2bMma::IteratorA0::Params params_A0;
    typename B2bMma::IteratorA0::TensorRef ref_A0;
    typename B2bMma::IteratorB0::Params params_B0;
@ -91,7 +98,7 @@ struct B2bGemm {
    //

    CUTLASS_HOST_DEVICE
    Params(): semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
    Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
      gemm_k_iterations_1(0), gemm_k_size_1(0) { }

    CUTLASS_HOST_DEVICE
@ -112,6 +119,7 @@ struct B2bGemm {
      problem_size_0(problem_size_0),
      problem_size_1(problem_size_1),
      grid_tiled_shape(grid_tiled_shape),
      swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
      params_A0(ref_A0.layout()),
      ref_A0(ref_A0),
      params_B0(ref_B0.layout()),
@ -200,6 +208,19 @@ struct B2bGemm {
      return Status::kErrorMisalignedOperand;
    }

    // Determine if fusion sizes are valid
    if(problem_size_0.m() != problem_size_1.m())
      return Status::kErrorInvalidProblem;

    if(problem_size_0.n() != problem_size_1.k())
      return Status::kErrorInvalidProblem;

    if(problem_size_0.n() > B2bMma::Shape0::kN)
      return Status::kErrorInvalidProblem;

    if(problem_size_1.n() > B2bMma::Shape1::kN)
      return Status::kErrorInvalidProblem;

    return Status::kSuccess;
  }
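  // The checks above spell out the back-to-back fusion contract: both GEMMs share the same M
  // extent, the first GEMM's N becomes the second GEMM's K, and each GEMM's N must fit inside a
  // single threadblock tile (Shape0::kN / Shape1::kN) so the intermediate result never has to
  // leave the threadblock.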

@ -210,7 +231,8 @@ struct B2bGemm {
    // Compute threadblock location
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
    cutlass::gemm::GemmCoord threadblock_tile_offset =
        threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    // Early exit if CTA is out of range
    if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
@ -313,7 +335,8 @@ struct B2bGemm {
    // Masked tile iterators constructed from members
    //

    threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
    threadblock_tile_offset =
        threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    //assume identity swizzle
    MatrixCoord threadblock_offset(
@ -333,7 +356,7 @@ struct B2bGemm {
      semaphore.fetch();

      // Indicate which position in a serial reduction the output operator is currently updating
      output_op_1.set_k_partition(threadblock_tile_offset.k());
      output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
    }

    // Tile iterator loading from source tensor.
@ -0,0 +1,521 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a pipelined Implicit GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/conv2d_problem_size.h"
|
||||
#include "cutlass/conv/conv3d_problem_size.h"
|
||||
#include "cutlass/epilogue/threadblock/output_iterator_parameter.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad)
|
||||
typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem
|
||||
>
|
||||
struct B2bImplicitGemmConvolution {
|
||||
|
||||
using B2bMma = B2bMma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueOutputOp0 = typename B2bMma::OutputOp;
|
||||
using EpilogueOutputOp1 = typename Epilogue::OutputOp;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
static Operator const kConvolutionalOperator = ConvOperator;
|
||||
|
||||
using ElementA = typename B2bMma::IteratorA0::Element;
|
||||
using LayoutA = typename B2bMma::IteratorA0::Layout;
|
||||
using ElementB = typename B2bMma::IteratorB0::Element;
|
||||
using LayoutB = typename B2bMma::IteratorB0::Layout;
|
||||
using ElementC = typename EpilogueOutputOp1::ElementOutput;
|
||||
|
||||
/// Set output tensor C layout
|
||||
using LayoutC = LayoutA;
|
||||
|
||||
using ElementAccumulator = typename EpilogueOutputOp0::ElementAccumulator;
|
||||
using ElementCompute = typename EpilogueOutputOp0::ElementCompute;
|
||||
|
||||
/// Scale and Bias
|
||||
using ElementScaleBias = typename B2bMma::IteratorAccumulatorScaleBias::Element;
|
||||
using LayoutScaleBias = typename B2bMma::IteratorAccumulatorScaleBias::Layout;
|
||||
|
||||
using WarpMmaOperator0 = typename B2bMma::Policy0::Operator;
|
||||
using WarpMmaOperator1 = typename B2bMma::Policy1::Operator;
|
||||
|
||||
using ArchMmaOperator = typename WarpMmaOperator0::ArchMmaOperator;
|
||||
using MathOperator = typename ArchMmaOperator::Operator;
|
||||
|
||||
using OperatorClass = typename WarpMmaOperator0::OperatorClass;
|
||||
using ArchTag = typename WarpMmaOperator0::ArchTag;
|
||||
|
||||
using ThreadblockShape0 = typename B2bMma::Shape0;
|
||||
using ThreadblockShape1 = typename B2bMma::Shape1;
|
||||
using WarpShape0 = typename WarpMmaOperator0::Shape;
|
||||
using WarpShape1 = typename WarpMmaOperator1::Shape;
|
||||
using InstructionShape = typename ArchMmaOperator::Shape;
|
||||
|
||||
static int const kStages = B2bMma::kStages;
|
||||
static IteratorAlgorithm const kIteratorAlgorithm = B2bMma::IteratorA0::kIteratorAlgorithm;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount0 = typename B2bMma::WarpCount0;
|
||||
static int const kThreadCount = 32 * WarpCount0::kCount;
|
||||
|
||||
using TensorRefA0 = typename B2bMma::IteratorA0::TensorRef;
|
||||
using TensorRefB0 = typename B2bMma::IteratorB0::TensorRef;
|
||||
using TensorRefScaleBias0 = typename B2bMma::IteratorAccumulatorScaleBias::TensorRef;
|
||||
using TensorRefB1 = typename B2bMma::IteratorB1::TensorRef;
|
||||
using TensorRefC = cutlass::TensorRef<ElementC, LayoutC>;
|
||||
|
||||
/// Check iterator A and B convolution dimension are the same and
|
||||
// set device::B2bImplicitGemmConvolution::kConvDim
|
||||
static_assert(B2bMma::IteratorA0::kConvDim == B2bMma::IteratorB0::kConvDim,
|
||||
"Convolution on different dimensions is not supported");
|
||||
static int const kConvDim = B2bMma::IteratorA0::kConvDim;
|
||||
|
||||
/// Conv dimension and problem size structure (Conv2d or Conv3d)
|
||||
using ConvProblemSize = ConvProblemSize_;
|
||||
|
||||
/// Wgrad C stride idx for implicit gemm algorithm
|
||||
// Conv2d row-major matrix C (KxRSC)
|
||||
// Conv3d row-major matrix C (KxTRSC)
|
||||
static int const kWgradCStrideIdx =
|
||||
cutlass::platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
|
||||
|
||||
/// This chooses the appropriate stride element of the C tensor.
|
||||
static int const kTensorCStrideIdx =
|
||||
(kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0);
|
||||
|
||||
//
|
||||
//
|
||||
//
|
||||
using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter<
|
||||
LayoutC,
|
||||
typename Epilogue::OutputTileIterator::Layout,
|
||||
TensorRefC,
|
||||
ConvOperator,
|
||||
ConvProblemSize
|
||||
>;
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
ConvProblemSize problem_size_0;
|
||||
ConvProblemSize problem_size_1;
|
||||
TensorRefA0 ref_A0;
|
||||
TensorRefB0 ref_B0;
|
||||
TensorRefC ref_C0;
|
||||
TensorRefScaleBias0 ref_Scale0;
|
||||
TensorRefScaleBias0 ref_Bias0;
|
||||
TensorRefB1 ref_B1;
|
||||
TensorRefC ref_C1;
|
||||
TensorRefC ref_D1;
|
||||
typename EpilogueOutputOp0::Params output_op_0;
|
||||
typename EpilogueOutputOp1::Params output_op_1;
|
||||
SplitKMode split_k_mode;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
ConvProblemSize const & problem_size_0,
|
||||
ConvProblemSize const & problem_size_1
|
||||
):
|
||||
problem_size_0(problem_size_0),
|
||||
problem_size_1(problem_size_1) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
ConvProblemSize const & problem_size_0,
|
||||
ConvProblemSize const & problem_size_1,
|
||||
TensorRefA0 const & ref_A0,
|
||||
TensorRefB0 const & ref_B0,
|
||||
TensorRefC const & ref_C0,
|
||||
TensorRefScaleBias0 const & ref_Scale0,
|
||||
TensorRefScaleBias0 const & ref_Bias0,
|
||||
TensorRefB1 const & ref_B1,
|
||||
TensorRefC const & ref_C1,
|
||||
TensorRefC const & ref_D1,
|
||||
typename EpilogueOutputOp0::Params const & output_op_0,
|
||||
typename EpilogueOutputOp1::Params const & output_op_1,
|
||||
SplitKMode const & split_k_mode = SplitKMode::kSerial
|
||||
):
|
||||
problem_size_0(problem_size_0),
|
||||
problem_size_1(problem_size_1),
|
||||
ref_A0(ref_A0),
|
||||
ref_B0(ref_B0),
|
||||
ref_C0(ref_C0),
|
||||
ref_Scale0(ref_Scale0),
|
||||
ref_Bias0(ref_Bias0),
|
||||
ref_B1(ref_B1),
|
||||
ref_C1(ref_C1),
|
||||
ref_D1(ref_D1),
|
||||
output_op_0(output_op_0),
|
||||
output_op_1(output_op_1),
|
||||
split_k_mode(split_k_mode)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
ConvProblemSize problem_size_0;
|
||||
ConvProblemSize problem_size_1;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
gemm::GemmCoord implicit_gemm_problem_size_0;
|
||||
gemm::GemmCoord implicit_gemm_problem_size_1;
|
||||
int swizzle_log_tile;
|
||||
int gemm_k_iterations_0;
|
||||
int gemm_k_iterations_1;
|
||||
typename B2bMma::IteratorA0::Params iterator_A0;
|
||||
typename B2bMma::IteratorA0::Element const *ptr_A0;
|
||||
typename B2bMma::IteratorB0::Params iterator_B0;
|
||||
typename B2bMma::IteratorB0::Element const *ptr_B0;
|
||||
typename Epilogue::OutputTileIterator::Params iterator_C0;
|
||||
typename Epilogue::OutputTileIterator::Element *ptr_C0;
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::Element *ptr_Scale0;
|
||||
typename B2bMma::IteratorAccumulatorScaleBias::Element *ptr_Bias0;
|
||||
typename B2bMma::IteratorB1::Params iterator_B1;
|
||||
typename B2bMma::IteratorB1::Element const *ptr_B1;
|
||||
typename Epilogue::OutputTileIterator::Params iterator_C1;
|
||||
typename Epilogue::OutputTileIterator::Element *ptr_C1;
|
||||
typename Epilogue::OutputTileIterator::Params iterator_D1;
|
||||
typename Epilogue::OutputTileIterator::Element *ptr_D1;
|
||||
typename EpilogueOutputOp0::Params output_op_0;
|
||||
typename EpilogueOutputOp1::Params output_op_1;
|
||||
int *semaphore;
|
||||
SplitKMode split_k_mode;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(): swizzle_log_tile(0), gemm_k_iterations_0(0), gemm_k_iterations_1(0) { }
|
||||
|
||||
///
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
Arguments const &args,
|
||||
int *semaphore = nullptr
|
||||
):
|
||||
problem_size_0(args.problem_size_0),
|
||||
problem_size_1(args.problem_size_1),
|
||||
implicit_gemm_problem_size_0(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0)),
|
||||
implicit_gemm_problem_size_1(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_1)),
|
||||
iterator_A0(B2bMma::IteratorA0::getParams(args.problem_size_0, args.ref_A0.layout())),
|
||||
ptr_A0(args.ref_A0.data()),
|
||||
iterator_B0(args.problem_size_0, args.ref_B0.layout()),
|
||||
ptr_B0(args.ref_B0.data()),
|
||||
iterator_C0(ConvOutputIteratorParameter::layout(args.ref_C0)),
|
||||
ptr_C0(args.ref_C0.data()),
|
||||
ptr_Scale0(args.ref_Scale0.data()),
|
||||
ptr_Bias0(args.ref_Bias0.data()),
|
||||
iterator_B1(args.problem_size_1, args.ref_B1.layout()),
|
||||
ptr_B1(args.ref_B1.data()),
|
||||
iterator_C1(ConvOutputIteratorParameter::layout(args.ref_C1)),
|
||||
ptr_C1(args.ref_C1.data()),
|
||||
iterator_D1(ConvOutputIteratorParameter::layout(args.ref_D1)),
|
||||
ptr_D1(args.ref_D1.data()),
|
||||
output_op_0(args.output_op_0),
|
||||
output_op_1(args.output_op_1),
|
||||
semaphore(semaphore),
|
||||
split_k_mode(args.split_k_mode)
|
||||
{
|
||||
gemm_k_iterations_0 = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape0::kK, args.problem_size_0);
|
||||
gemm_k_iterations_1 = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape1::kK, args.problem_size_1);
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||||
implicit_gemm_problem_size_0,
|
||||
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
|
||||
args.problem_size_0.split_k_slices);
|
||||
|
||||
swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape);
|
||||
}
|
||||
};
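  // Note that grid_tiled_shape (and hence the launch grid) is derived from the first
  // convolution's implicit GEMM problem only; the fused second convolution reuses the same
  // threadblock tiling, with gemm_k_iterations_1 covering its own K extent.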
|
||||
|
||||
/// Shared memory storage structure
|
||||
union SharedStorage {
|
||||
typename B2bMma::B2bMmaSharedStorage main_loop;
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
B2bImplicitGemmConvolution() { }
|
||||
|
||||
/// Executes one ImplicitGEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
||||
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_idx =
|
||||
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() ||
|
||||
params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) {
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename B2bMma::IteratorA0 iterator_A0(
|
||||
params.iterator_A0,
|
||||
params.problem_size_0,
|
||||
params.ptr_A0,
|
||||
thread_idx,
|
||||
MatrixCoord(
|
||||
threadblock_tile_idx.m() * B2bMma::Shape0::kM,
|
||||
threadblock_tile_idx.k() * B2bMma::Shape0::kK
|
||||
)
|
||||
);
|
||||
|
||||
typename B2bMma::IteratorB0 iterator_B0(
|
||||
params.iterator_B0,
|
||||
params.problem_size_0,
|
||||
params.ptr_B0,
|
||||
thread_idx,
|
||||
MatrixCoord(
|
||||
threadblock_tile_idx.k() * B2bMma::Shape0::kK,
|
||||
threadblock_tile_idx.n() * B2bMma::Shape0::kN
|
||||
)
|
||||
);
|
||||
|
||||
typename B2bMma::IteratorB1 iterator_B1(
|
||||
params.iterator_B1,
|
||||
params.problem_size_1,
|
||||
params.ptr_B1,
|
||||
thread_idx,
|
||||
MatrixCoord(
|
||||
threadblock_tile_idx.k() * B2bMma::Shape1::kK,
|
||||
threadblock_tile_idx.n() * B2bMma::Shape1::kN
|
||||
)
|
||||
);
|
||||
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
// Construct iterators to accumulator scale/bias vector
|
||||
typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0(
|
||||
params.ptr_Scale0,
|
||||
{1, params.problem_size_0.K},
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
MatrixCoord(
|
||||
0, threadblock_tile_idx.n() * B2bMma::Shape0::kN
|
||||
)
|
||||
);
|
||||
|
||||
typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0(
|
||||
params.ptr_Bias0,
|
||||
{1, params.problem_size_0.K},
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
MatrixCoord(
|
||||
0, threadblock_tile_idx.n() * B2bMma::Shape0::kN
|
||||
)
|
||||
);
|
||||
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
|
||||
EpilogueOutputOp0 output_op_0(params.output_op_0);
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
typename B2bMma::FragmentC0 src_accum;
|
||||
typename B2bMma::FragmentC1 accumulators;
|
||||
|
||||
src_accum.clear();
|
||||
accumulators.clear();
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
b2bMma(params.gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
|
||||
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
EpilogueOutputOp1 output_op_1(params.output_op_1);
|
||||
|
||||
// Construct the semaphore.
|
||||
int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m();
|
||||
|
||||
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
|
||||
|
||||
// Compute logical position within grid
|
||||
threadblock_tile_idx =
|
||||
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// If performing a reduction via split-K, fetch the initial synchronization
|
||||
if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// Fetch the synchronization lock initially but do not block.
|
||||
semaphore.fetch();
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op_1.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k());
|
||||
}
|
||||
|
||||
MatrixCoord threadblock_offset(
|
||||
threadblock_tile_idx.m() * B2bMma::Shape1::kM,
|
||||
threadblock_tile_idx.n() * B2bMma::Shape1::kN
|
||||
);
|
||||
|
||||
// Tile iterator writing to destination tensor
|
||||
typename Epilogue::OutputTileIterator iterator_D1(
|
||||
params.iterator_D1,
|
||||
params.ptr_D1,
|
||||
ConvOutputIteratorParameter::extent(params.problem_size_1),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
|
||||
// Tile iterator reading from source accumulator tensor
|
||||
typename Epilogue::OutputTileIterator iterator_C1(
|
||||
params.iterator_C1,
|
||||
params.ptr_C1,
|
||||
ConvOutputIteratorParameter::extent(params.problem_size_1),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
|
||||
|
||||
// Construct the epilogue
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue,
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
lane_idx);
|
||||
|
||||
// Wait on the semaphore - this latency may have been covered by iterator construction
|
||||
if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
|
||||
if (threadblock_tile_idx.k()) {
|
||||
iterator_C1 = iterator_D1;
|
||||
}
|
||||
|
||||
semaphore.wait(threadblock_tile_idx.k());
|
||||
|
||||
__threadfence();
|
||||
}
|
||||
// Each split-k-slice writes to a unique tensor location
|
||||
else if (params.split_k_mode == SplitKMode::kParallel) {
|
||||
iterator_D1.add_pointer_offset(threadblock_tile_idx.k() *
|
||||
cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size_1));
|
||||
}
|
||||
|
||||
// Run efficient epilogue
|
||||
epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
|
||||
|
||||
//
|
||||
// Release the semaphore
|
||||
//
|
||||
|
||||
if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
int lock = 0;
|
||||
if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) {
|
||||
|
||||
// The final threadblock resets the semaphore for subsequent grids.
|
||||
lock = 0;
|
||||
}
|
||||
else {
|
||||
// Otherwise, the semaphore is incremented
|
||||
lock = threadblock_tile_idx.k() + 1;
|
||||
}
|
||||
|
||||
semaphore.release(lock);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -0,0 +1,94 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
|
||||
matrix multiply-add with the appropriate threadblock-scoped epilogue.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/conv/kernel/default_conv2d.h"
|
||||
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
|
||||
|
||||
#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
|
||||
#include "cutlass/transform/threadblock/vector_iterator.h"
|
||||
#include "cutlass/transform/warp/vector_fragment_iterator.h"
|
||||
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "kernel/b2b_implicit_gemm_convolution.h"
|
||||
#include "threadblock/b2b_implicit_gemm_pipelined.h"
|
||||
#include "threadblock/b2b_implicit_gemm_multistage.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv2dFprop
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename OperatorClass,
  typename ArchTag,
  typename ThreadblockShape0,
  typename ThreadblockShape1,
  typename WarpShape0,
  typename WarpShape1,
  typename InstructionShape,
  typename EpilogueOutputOp0,
  typename EpilogueOutputOp1,
  typename ThreadblockSwizzle,
  int Stages,
  typename MathOperatorTag,
  conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
  bool SmemAccumulator = false
> struct DefaultB2bConv2dFprop;
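// The primary template above is only declared here; concrete fused Conv2dFprop kernels come from
// partial specializations (such as the OpClassTensorOp, analytic-iterator, 2-stage version that
// follows), selected by OperatorClass, IteratorAlgorithm, Stages, and SmemAccumulator.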

} // namespace kernel
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -0,0 +1,749 @@
|
||||
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief
    Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
    matrix multiply-add with the appropriate threadblock-scoped epilogue.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_conv2d.h"

#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"

#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
#include "cutlass/transform/threadblock/vector_iterator.h"
#include "cutlass/transform/warp/vector_fragment_iterator.h"

#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"

#include "kernel/default_b2b_conv2d_fprop.h"
#include "kernel/b2b_implicit_gemm_convolution.h"
#include "threadblock/b2b_implicit_gemm_pipelined.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////
// OpClassTensorOp convolutions
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm
/// and 2 stage pipeline.
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
using AccumulatorLayout = cutlass::layout::ColumnMajor;
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 2;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Warp-level iterators to load scale and bias vectors
|
||||
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
||||
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
||||
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB1
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmPipelined<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
ThreadblockShape1,
|
||||
FragmentIteratorA1,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorA1ScaleBias,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename detail::DefaultConvEpilogue<
|
||||
ArchTag,
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
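
// A minimal usage sketch of the specialization above, kept as a comment. The element
// types, tile shapes, epilogues, and swizzle are illustrative assumptions chosen for
// the example, not values mandated by this header:
//
//   using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
//
//   using EpilogueOutputOp0 = cutlass::epilogue::thread::LinearCombinationRelu<
//       cutlass::half_t, InstructionShape::kM * InstructionShape::kN / 32,
//       float, float, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;
//
//   using EpilogueOutputOp1 = cutlass::epilogue::thread::LinearCombinationRelu<
//       cutlass::half_t, 128 / cutlass::sizeof_bits<cutlass::half_t>::value, float, float>;
//
//   using B2bFpropAnalytic2StageKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
//       cutlass::half_t, cutlass::layout::TensorNHWC,   // A: activations
//       cutlass::half_t, cutlass::layout::TensorNHWC,   // B: filters
//       cutlass::half_t, cutlass::layout::TensorNHWC,   // C/D: output
//       float,                                          // accumulator
//       cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
//       cutlass::gemm::GemmShape<64, 64, 32>,           // threadblock tile, GEMM 0
//       cutlass::gemm::GemmShape<64, 128, 32>,          // threadblock tile, GEMM 1
//       cutlass::gemm::GemmShape<32, 64, 32>,           // warp tile, GEMM 0
//       cutlass::gemm::GemmShape<32, 128, 32>,          // warp tile, GEMM 1
//       InstructionShape,
//       EpilogueOutputOp0, EpilogueOutputOp1,
//       cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
//       2,                                              // stages (selects this specialization)
//       cutlass::arch::OpMultiplyAdd,
//       cutlass::conv::IteratorAlgorithm::kAnalytic
//   >::Kernel;
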
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage
|
||||
/// pipeline with interleaved layout.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename ElementB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
layout::TensorNCxHWx<InterleavedK>,
|
||||
ElementB,
|
||||
layout::TensorCxRSKx<InterleavedK>,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
false
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag, true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag, true>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
||||
ThreadMapA0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
using AccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 4;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Warp-level iterators to load scale and bias vectors
|
||||
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
||||
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
||||
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB1
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmPipelined<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
ThreadblockShape1,
|
||||
FragmentIteratorA1,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorA1ScaleBias,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount,
|
||||
InterleavedK
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
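
// A corresponding sketch for this interleaved specialization (again illustrative only;
// the int8 tile sizes and epilogue parameters below are assumptions, not fixed by this
// header). InterleavedK = 32 selects NC/32HW32 activations and C/32RSK32 filters:
//
//   using ElementOutput = int8_t;
//   using EpilogueOutputOp0 = cutlass::epilogue::thread::LinearCombinationRelu<
//       ElementOutput, 8, int32_t, float,
//       cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;
//   using EpilogueOutputOp1 = cutlass::epilogue::thread::LinearCombinationRelu<
//       ElementOutput, 8, int32_t, float>;
//
//   using B2bFpropInterleavedKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
//       int8_t, cutlass::layout::TensorNCxHWx<32>,
//       int8_t, cutlass::layout::TensorCxRSKx<32>,
//       ElementOutput, cutlass::layout::TensorNCxHWx<32>,
//       int32_t,
//       cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
//       cutlass::gemm::GemmShape<64, 64, 64>,  cutlass::gemm::GemmShape<64, 128, 64>,
//       cutlass::gemm::GemmShape<32, 64, 64>,  cutlass::gemm::GemmShape<32, 128, 64>,
//       cutlass::gemm::GemmShape<8, 8, 16>,
//       EpilogueOutputOp0, EpilogueOutputOp1,
//       cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
//       2,
//       cutlass::arch::OpMultiplyAddSaturate,
//       cutlass::conv::IteratorAlgorithm::kAnalytic
//   >::Kernel;
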
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm
|
||||
/// and 2 stage pipeline.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
using AccumulatorLayout = cutlass::layout::ColumnMajor;
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 2;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Warp-level iterators to load scale and bias vectors
|
||||
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
||||
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
||||
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB1
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmPipelined<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
ThreadblockShape1,
|
||||
FragmentIteratorA1,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorA1ScaleBias,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename detail::DefaultConvEpilogue<
|
||||
ArchTag,
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
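
// Relative to the analytic sketch earlier in this file, selecting this specialization
// only changes the trailing IteratorAlgorithm tag. A hypothetical selector alias (all
// concrete types and the two epilogue operators as defined in that earlier sketch)
// makes the choice explicit in one place:
//
//   template <cutlass::conv::IteratorAlgorithm kAlgorithm>
//   using B2bFprop2StageKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
//       cutlass::half_t, cutlass::layout::TensorNHWC,
//       cutlass::half_t, cutlass::layout::TensorNHWC,
//       cutlass::half_t, cutlass::layout::TensorNHWC,
//       float, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
//       cutlass::gemm::GemmShape<64, 64, 32>,  cutlass::gemm::GemmShape<64, 128, 32>,
//       cutlass::gemm::GemmShape<32, 64, 32>,  cutlass::gemm::GemmShape<32, 128, 32>,
//       cutlass::gemm::GemmShape<16, 8, 8>,
//       EpilogueOutputOp0, EpilogueOutputOp1,   // as defined in the earlier sketch
//       cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
//       2, cutlass::arch::OpMultiplyAdd,
//       kAlgorithm                              // kAnalytic or kOptimized
//   >::Kernel;
//
//   using OptimizedKernel = B2bFprop2StageKernel<cutlass::conv::IteratorAlgorithm::kOptimized>;
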
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage
|
||||
/// pipeline with interleaved layout.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename ElementB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
layout::TensorNCxHWx<InterleavedK>,
|
||||
ElementB,
|
||||
layout::TensorCxRSKx<InterleavedK>,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag, true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag, true>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
||||
ThreadMapA0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
using AccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 4;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Warp-level iterators to load scale and bias vectors
|
||||
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
||||
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
||||
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
||||
|
||||
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB1
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmPipelined<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
ThreadblockShape1,
|
||||
FragmentIteratorA1,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorA1ScaleBias,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount,
|
||||
InterleavedK
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -0,0 +1,740 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief
    Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
    matrix multiply-add with the appropriate threadblock-scoped epilogue.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_conv2d.h"

#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"

#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
#include "cutlass/transform/threadblock/vector_iterator.h"
#include "cutlass/transform/warp/vector_fragment_iterator.h"

#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"

#include "kernel/default_b2b_conv2d_fprop.h"
#include "kernel/b2b_implicit_gemm_convolution.h"
#include "threadblock/b2b_implicit_gemm_multistage.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////
// OpClassTensorOp convolutions
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
/// pipeline.
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA0
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB0
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
using AccumulatorLayout = cutlass::layout::ColumnMajor;
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 2;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Warp-level iterators to load scale and bias vectors
|
||||
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
||||
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
||||
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB1
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmMultistage<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
arch::CacheOperation::Global,
|
||||
ThreadblockShape1,
|
||||
FragmentIteratorA1,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorA1ScaleBias,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
arch::CacheOperation::Global,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
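
// Usage sketch for the multistage path (illustrative assumptions only: an Ampere
// target, three pipeline stages, fp16 NHWC tensors, and the two epilogue operators
// defined as in the fp16 sketch shown with the pipelined specializations):
//
//   using B2bFpropMultistageKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
//       cutlass::half_t, cutlass::layout::TensorNHWC,
//       cutlass::half_t, cutlass::layout::TensorNHWC,
//       cutlass::half_t, cutlass::layout::TensorNHWC,
//       float, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
//       cutlass::gemm::GemmShape<64, 64, 32>,  cutlass::gemm::GemmShape<64, 128, 32>,
//       cutlass::gemm::GemmShape<32, 64, 32>,  cutlass::gemm::GemmShape<32, 128, 32>,
//       cutlass::gemm::GemmShape<16, 8, 16>,   // Ampere fp16 tensor op instruction shape
//       EpilogueOutputOp0, EpilogueOutputOp1,
//       cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
//       3,                                     // Stages > 2 selects this multistage path
//       cutlass::arch::OpMultiplyAdd,
//       cutlass::conv::IteratorAlgorithm::kAnalytic
//   >::Kernel;
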
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
|
||||
/// pipeline with interleaved layout.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename ElementB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
layout::TensorNCxHWx<InterleavedK>,
|
||||
ElementB,
|
||||
layout::TensorCxRSKx<InterleavedK>,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag, true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag, true>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
||||
ThreadMapA0
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB0
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
using AccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 4;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Warp-level iterators to load scale and bias vectors
|
||||
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
||||
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
||||
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
||||
|
||||
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB1
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmMultistage<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
arch::CacheOperation::Global,
|
||||
ThreadblockShape1,
|
||||
FragmentIteratorA1,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorA1ScaleBias,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
arch::CacheOperation::Global,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount,
|
||||
InterleavedK
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
|
||||
/// multistage pipeline.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA0
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB0
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
using AccumulatorLayout = cutlass::layout::ColumnMajor;
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 2;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Warp-level iterators to load scale and bias vectors
|
||||
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
||||
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
||||
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB1
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmMultistage<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
arch::CacheOperation::Global,
|
||||
ThreadblockShape1,
|
||||
FragmentIteratorA1,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorA1ScaleBias,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
arch::CacheOperation::Global,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
/// multistage pipeline with interleaved layout.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename ElementB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
layout::TensorNCxHWx<InterleavedK>,
|
||||
ElementB,
|
||||
layout::TensorCxRSKx<InterleavedK>,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag, true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag, true>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
||||
ThreadMapA0
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB0
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
using AccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 4;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Warp-level iterators to load scale and bias vectors
|
||||
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
|
||||
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
|
||||
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
|
||||
|
||||
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB1
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmMultistage<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
arch::CacheOperation::Global,
|
||||
ThreadblockShape1,
|
||||
FragmentIteratorA1,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorA1ScaleBias,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
arch::CacheOperation::Global,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount,
|
||||
InterleavedK
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -0,0 +1,817 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief
    Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
    matrix multiply-add with the appropriate threadblock-scoped epilogue.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_conv2d.h"

#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"

#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
#include "cutlass/transform/threadblock/vector_iterator.h"
#include "cutlass/transform/warp/vector_fragment_iterator.h"

#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"

#include "kernel/default_b2b_conv2d_fprop.h"
#include "kernel/b2b_implicit_gemm_convolution.h"
#include "threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm
/// and 2 stage pipeline.
/// Accumulator will be staged in shared memory.
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
true
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 2;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB1
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
ElementC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator,
|
||||
SmemIteratorD0,
|
||||
ThreadblockShape1,
|
||||
WarpIteratorA1,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename detail::DefaultConvEpilogue<
|
||||
ArchTag,
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
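// ---------------------------------------------------------------------------------------------
// Illustrative sketch, not part of the diff above: pulling a concrete kernel type out of a
// specialization like the one defined above. All concrete types and tile sizes below are
// assumptions chosen for illustration (fp16 NHWC tensors, float accumulation, Turing-era tile
// shapes); the template argument order simply mirrors the partial specialization above, and
// "kernel/default_b2b_conv2d_fprop.h" is the example-local header already included by this file.
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "kernel/default_b2b_conv2d_fprop.h"

using B2bFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
    cutlass::half_t, cutlass::layout::TensorNHWC,    // A: activations
    cutlass::half_t, cutlass::layout::TensorNHWC,    // B: filters
    cutlass::half_t, cutlass::layout::TensorNHWC,    // C/D: output
    float,                                           // accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    cutlass::gemm::GemmShape<64, 64, 32>,            // ThreadblockShape0
    cutlass::gemm::GemmShape<64, 128, 32>,           // ThreadblockShape1
    cutlass::gemm::GemmShape<32, 64, 32>,            // WarpShape0
    cutlass::gemm::GemmShape<32, 128, 32>,           // WarpShape1
    cutlass::gemm::GemmShape<16, 8, 8>,              // InstructionShape (Turing fp16 MMA)
    cutlass::epilogue::thread::LinearCombinationRelu<cutlass::half_t, 8, float, float>,  // after conv0
    cutlass::epilogue::thread::LinearCombinationRelu<cutlass::half_t, 8, float, float>,  // after conv1
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2,                                               // Stages
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kAnalytic,
    true                                             // stage the conv0 accumulator in shared memory
>::Kernel;
// ---------------------------------------------------------------------------------------------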
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage
|
||||
/// pipeline with interleaved layout.
|
||||
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename ElementB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
layout::TensorNCxHWx<InterleavedK>,
|
||||
ElementB,
|
||||
layout::TensorCxRSKx<InterleavedK>,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
true
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag, true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag, true>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
||||
ThreadMapA0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 4; //For interleaved layout
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB1
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
ElementC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator,
|
||||
SmemIteratorD0,
|
||||
ThreadblockShape1,
|
||||
WarpIteratorA1,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount,
|
||||
InterleavedK
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
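// ---------------------------------------------------------------------------------------------
// Illustrative sketch, not part of the diff above: the IteratorAccumulatorScaleBias defined in the
// specializations above reads two short per-channel vectors -- a scale and a bias for each output
// channel of the first convolution -- which are applied to the intermediate activations before the
// second convolution. A host-side allocation for them could look like the sketch below; the names,
// the element type, and the {1, channels} extent are assumptions for illustration.
#include "cutlass/layout/matrix.h"
#include "cutlass/util/host_tensor.h"

void allocate_scale_bias_vectors(int output_channels_conv0) {
  using ElementScaleBias = float;
  using LayoutScaleBias  = cutlass::layout::RowMajor;

  // One scale and one bias value per output channel of conv0.
  cutlass::HostTensor<ElementScaleBias, LayoutScaleBias> tensor_scale0({1, output_channels_conv0});
  cutlass::HostTensor<ElementScaleBias, LayoutScaleBias> tensor_bias0({1, output_channels_conv0});

  // ... fill on the host, then call tensor_scale0.sync_device() / tensor_bias0.sync_device()
  // before passing the device references to the fused kernel's arguments.
  (void)tensor_scale0;
  (void)tensor_bias0;
}
// ---------------------------------------------------------------------------------------------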
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm
|
||||
/// and 2 stage pipeline.
|
||||
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
true
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 2;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB1
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
ElementC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator,
|
||||
SmemIteratorD0,
|
||||
ThreadblockShape1,
|
||||
WarpIteratorA1,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename detail::DefaultConvEpilogue<
|
||||
ArchTag,
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage
|
||||
/// pipeline with interleaved layout.
|
||||
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename ElementB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
layout::TensorNCxHWx<InterleavedK>,
|
||||
ElementB,
|
||||
layout::TensorCxRSKx<InterleavedK>,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
true
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag, true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag, true>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
||||
ThreadMapA0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB0
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 4; //For interleaved layout
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB1
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
ElementC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator,
|
||||
SmemIteratorD0,
|
||||
ThreadblockShape1,
|
||||
WarpIteratorA1,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount,
|
||||
InterleavedK
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,804 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
|
||||
matrix multiply-add with the appropriate threadblock-scoped epilogue.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/conv/kernel/default_conv2d.h"
|
||||
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
|
||||
|
||||
#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
|
||||
#include "cutlass/transform/threadblock/vector_iterator.h"
|
||||
#include "cutlass/transform/warp/vector_fragment_iterator.h"
|
||||
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "kernel/default_b2b_conv2d_fprop.h"
|
||||
#include "kernel/b2b_implicit_gemm_convolution.h"
|
||||
#include "threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
|
||||
/// pipeline.
|
||||
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
true
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA0
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB0
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 2;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB1
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
ElementC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
arch::CacheOperation::Global,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator,
|
||||
SmemIteratorD0,
|
||||
ThreadblockShape1,
|
||||
WarpIteratorA1,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
arch::CacheOperation::Global,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
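// ---------------------------------------------------------------------------------------------
// Illustrative sketch, not part of the diff above: the fused kernel built from a specialization
// like the one above runs two convolution problems back to back, with the second consuming the
// first's output while it is still resident on chip. The sizes below are assumptions chosen for
// illustration; the second convolution is written as a 1x1 so that its implicit-GEMM A operand is
// exactly the first convolution's activation output.
#include "cutlass/conv/conv2d_problem_size.h"

cutlass::conv::Conv2dProblemSize const problem_size_0(
    {128, 56, 56, 64},                        // input  N, H, W, C
    {64, 3, 3, 64},                           // filter K, R, S, C
    {1, 1, 1, 1},                             // padding
    {1, 1},                                   // stride
    {1, 1},                                   // dilation
    {128, 56, 56, 64},                        // output N, P, Q, K
    cutlass::conv::Mode::kCrossCorrelation,
    1                                         // split-K slices
);

cutlass::conv::Conv2dProblemSize const problem_size_1(
    {128, 56, 56, 64},                        // input is conv0's output
    {256, 1, 1, 64},                          // 1x1 filter
    {0, 0, 0, 0},
    {1, 1},
    {1, 1},
    {128, 56, 56, 256},
    cutlass::conv::Mode::kCrossCorrelation,
    1
);
// ---------------------------------------------------------------------------------------------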
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
|
||||
/// pipeline with interleaved layout.
|
||||
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename ElementB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
layout::TensorNCxHWx<InterleavedK>,
|
||||
ElementB,
|
||||
layout::TensorCxRSKx<InterleavedK>,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kAnalytic,
|
||||
true
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag, true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag, true>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
||||
ThreadMapA0
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB0
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 4;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kK>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB1
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
ElementC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
arch::CacheOperation::Global,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator,
|
||||
SmemIteratorD0,
|
||||
ThreadblockShape1,
|
||||
WarpIteratorA1,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
arch::CacheOperation::Global,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount,
|
||||
InterleavedK
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
|
||||
/// multistage pipeline.
|
||||
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
true
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA0
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB0
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 2;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB1
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
ElementC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
arch::CacheOperation::Global,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator,
|
||||
SmemIteratorD0,
|
||||
ThreadblockShape1,
|
||||
WarpIteratorA1,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
arch::CacheOperation::Global,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
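// ---------------------------------------------------------------------------------------------
// Illustrative sketch, not part of the diff above: a rough shared-memory budget for choosing
// Stages with a multistage, shared-memory-staged specialization like the one above. The formula
// is an approximation (it ignores layout padding and the second stage's own operand staging) and
// the tile sizes and byte counts are assumptions chosen for illustration.
#include <cstddef>

constexpr std::size_t kBytesA   = 2;            // sizeof(cutlass::half_t)
constexpr std::size_t kBytesB   = 2;
constexpr std::size_t kBytesAcc = 2;            // staged accumulator stored as ElementC
constexpr int kM = 64, kN = 128, kK = 32;       // ThreadblockShape0 (assumed)
constexpr int kStages = 3;

// Per-stage A and B tiles multiplied by the number of pipeline stages, plus one accumulator tile
// kept in shared memory between the two convolutions.
constexpr std::size_t kSmemEstimate =
    kStages * (std::size_t(kM) * kK * kBytesA + std::size_t(kK) * kN * kBytesB)
    + std::size_t(kM) * kN * kBytesAcc;

static_assert(kSmemEstimate <= 163 * 1024,
              "estimated tile storage exceeds SM80's opt-in shared memory per threadblock");
// ---------------------------------------------------------------------------------------------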
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
/// multistage pipeline with interleaved layout.
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename ElementB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape0,
|
||||
typename ThreadblockShape1,
|
||||
typename WarpShape0,
|
||||
typename WarpShape1,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp0,
|
||||
typename EpilogueOutputOp1,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
int InterleavedK
|
||||
>
|
||||
struct DefaultB2bConv2dFprop <
|
||||
ElementA,
|
||||
layout::TensorNCxHWx<InterleavedK>,
|
||||
ElementB,
|
||||
layout::TensorCxRSKx<InterleavedK>,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
true
|
||||
> {
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag, true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag, true>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
|
||||
using IteratorA0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, layout::TensorNCxHWx<InterleavedK>,
|
||||
ThreadMapA0
|
||||
>;
|
||||
|
||||
using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
// Note GEMM shared memory threadmap is used here because conv global memory
|
||||
// layout needs to be mapped to fprop which is similar to the crosswise
|
||||
// layout which is used by the interleaved GEMM shared memory threadmap.
|
||||
// The Interleaved GEMM global memory layout is similar to the congruous
|
||||
// layout.
|
||||
using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
|
||||
using IteratorB0 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB0
|
||||
>;
|
||||
|
||||
using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 4;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, layout::TensorCxRSKx<InterleavedK>,
|
||||
ThreadMapB1
|
||||
>;
|
||||
|
||||
using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
|
||||
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
ElementC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
// Define the Mma
|
||||
using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator<
|
||||
ThreadblockShape0,
|
||||
IteratorA0,
|
||||
SmemIteratorA0,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB0,
|
||||
SmemIteratorB0,
|
||||
arch::CacheOperation::Global,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator,
|
||||
SmemIteratorD0,
|
||||
ThreadblockShape1,
|
||||
WarpIteratorA1,
|
||||
IteratorB1,
|
||||
SmemIteratorB1,
|
||||
arch::CacheOperation::Global,
|
||||
EpilogueOutputOp0,
|
||||
MmaPolicy0,
|
||||
MmaPolicy1,
|
||||
Stages
|
||||
>;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
|
||||
ThreadblockShape1,
|
||||
WarpMmaTensorOp1,
|
||||
1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount,
|
||||
InterleavedK
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
|
||||
B2bMma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop
|
||||
>;
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
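// Illustrative sketch only (not part of this change): how the interleaved
// DefaultB2bConv2dFprop specialization above might be instantiated. All tile
// shapes, element types, and the example-local header path below are
// hypothetical choices, not values taken from the diff; the header is assumed
// to pull in the CUTLASS layout, arch, and conv definitions it needs.
#include "cutlass/gemm/gemm.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "kernel/default_b2b_conv2d_fprop.h"   // assumed example-local header

using ElementIn    = int8_t;
using ElementOut   = int8_t;
using ElementAccum = int32_t;

// Per-convolution tile shapes (the first conv feeds the second through shared memory).
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
using WarpShape0        = cutlass::gemm::GemmShape<32, 32, 64>;
using WarpShape1        = cutlass::gemm::GemmShape<32, 64, 64>;
using InstructionShape  = cutlass::gemm::GemmShape<16, 8, 32>;

using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu<
    ElementOut, 8, ElementAccum, float>;

using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
    ElementIn,  cutlass::layout::TensorNCxHWx<32>,
    ElementIn,  cutlass::layout::TensorCxRSKx<32>,
    ElementOut, cutlass::layout::TensorNCxHWx<32>,
    ElementAccum,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape0, ThreadblockShape1,
    WarpShape0, WarpShape1,
    InstructionShape,
    EpilogueOp, EpilogueOp,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,                                              // Stages
    cutlass::arch::OpMultiplyAddSaturate,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    true                                            // stage accumulator in shared memory
>::Kernel;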
@@ -1,29 +1,34 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
*modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
*notice, this list of conditions and the following disclaimer in the
|
||||
*documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its
|
||||
*contributors may be used to endorse or promote products derived from this
|
||||
*software without specific prior written permission.
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
|
||||
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
*OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING
|
||||
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
|
||||
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
|
||||
@@ -113,11 +118,80 @@ template <
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Beta is zero or not
|
||||
bool IsBetaZero = false
|
||||
/// Stage accumulator in shared memory
|
||||
bool SmemAccumulator = false
|
||||
>
|
||||
struct DefaultB2bGemm;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for Ampere Architecture
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp0,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
|
||||
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
|
||||
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1, InstructionShape,
|
||||
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial,
|
||||
Operator> {
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue
|
||||
using Epilogue =
|
||||
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
};
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for Turing Architecture
|
||||
@@ -217,7 +291,82 @@ struct DefaultB2bGemm<
|
||||
};
|
||||
|
||||
|
||||
/// Partial specialization for Turing IMMA Interleaved layout
|
||||
/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp0,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Number of Interleaved k
|
||||
int InterleavedK,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<
|
||||
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
|
||||
arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, Stages,
|
||||
SplitKSerial, Operator> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
|
||||
using ElementAccumulator = int32_t;
|
||||
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, Stages, Operator, EpilogueOutputOp0,
|
||||
true>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::
|
||||
DefaultInterleavedEpilogueTensorOp<
|
||||
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
/// Partial specialization for Turing Integer Tensor Core Interleaved layout
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
@@ -251,9 +400,7 @@ template <
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Is Beta zero or not
|
||||
bool IsBetaZero>
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
kAlignmentA, ElementB,
|
||||
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
@@ -261,7 +408,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
int32_t, arch::OpClassTensorOp, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, 2, SplitKSerial, Operator, IsBetaZero> {
|
||||
ThreadblockSwizzle, 2, SplitKSerial, Operator> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
@@ -280,8 +427,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::
|
||||
DefaultInterleavedEpilogueTensorOp<
|
||||
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK,
|
||||
IsBetaZero>::Epilogue;
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
@@ -0,0 +1,397 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
|
||||
the appropriate threadblock-scoped epilogue.
|
||||
|
||||
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
|
||||
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
|
||||
specializations here choose 'device::GemmTransposed' to implement this functionality.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/epilogue.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/gemm_pipelined.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
|
||||
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
||||
#include "cutlass/transform/threadblock/vector_iterator.h"
|
||||
#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
|
||||
|
||||
#include "kernel/b2b_gemm.h"
|
||||
#include "threadblock/default_b2b_mma.h"
|
||||
#include "threadblock/default_b2b_mma_smem_accumulator.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for Ampere Architecture
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp0,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
|
||||
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
|
||||
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1, InstructionShape,
|
||||
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial,
|
||||
Operator, true> {
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, Stages, Operator, EpilogueOutputOp0, false, true>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue
|
||||
using Epilogue =
|
||||
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
};
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for Turing Architecture
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp0,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// If true, kernel is configured to support serial reduction in the epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator
|
||||
>
|
||||
struct DefaultB2bGemm<
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementC, layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
SplitKSerial,
|
||||
Operator,
|
||||
true
|
||||
> {
|
||||
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
kAlignmentA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
kAlignmentB,
|
||||
ElementAccumulator,
|
||||
layout::RowMajor,
|
||||
arch::OpClassTensorOp,
|
||||
arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
2,
|
||||
Operator,
|
||||
EpilogueOutputOp0,
|
||||
false,
|
||||
true
|
||||
>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape1,
|
||||
typename B2bMma::Operator1,
|
||||
kPartitionsK1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount
|
||||
>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
};
|
||||
|
||||
|
||||
/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp0,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Number of Interleaved k
|
||||
int InterleavedK,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<
|
||||
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
|
||||
arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, Stages,
|
||||
SplitKSerial, Operator, true> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
|
||||
using ElementAccumulator = int32_t;
|
||||
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, Stages, Operator, EpilogueOutputOp0,
|
||||
true, true>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::
|
||||
DefaultInterleavedEpilogueTensorOp<
|
||||
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
/// Partial specialization for Turing Integer Tensor Core Interleaved layout
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp0,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of Interleaved k
|
||||
int InterleavedK,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
kAlignmentA, ElementB,
|
||||
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
int32_t, arch::OpClassTensorOp, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, 2, SplitKSerial, Operator, true> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
|
||||
using ElementAccumulator = int32_t;
|
||||
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, 2, Operator, EpilogueOutputOp0, true, true>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue for the 2nd Gemm
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::
|
||||
DefaultInterleavedEpilogueTensorOp<
|
||||
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
} // namespace gemm
} // namespace cutlass
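// Illustrative sketch only (not part of this change): obtaining a fused
// back-to-back GEMM kernel from the Sm80 shared-memory-accumulator
// specialization above. Element types, alignments, and tile shapes are
// hypothetical example values, and the example-local header path is an
// assumption; the trailing 'true' selects the SmemAccumulator path added here.
#include "cutlass/gemm/gemm.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "kernel/default_b2b_gemm_smem_accumulator.h"   // assumed example-local header

using EpilogueOp0 = cutlass::epilogue::thread::LinearCombinationRelu<
    cutlass::half_t, 8, float, float>;
using EpilogueOp1 = EpilogueOp0;

using B2bGemmKernel = typename cutlass::gemm::kernel::DefaultB2bGemm<
    cutlass::half_t, cutlass::layout::RowMajor, 8,      // A operand and alignment
    cutlass::half_t, cutlass::layout::ColumnMajor, 8,   // B operand and alignment
    cutlass::half_t, cutlass::layout::RowMajor,         // C/D operands
    float,                                              // accumulator type
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<64, 64, 32>,   // threadblock tile, GEMM 0
    cutlass::gemm::GemmShape<64, 128, 32>,  // threadblock tile, GEMM 1
    cutlass::gemm::GemmShape<32, 32, 32>,   // warp tile, GEMM 0
    cutlass::gemm::GemmShape<32, 64, 32>,   // warp tile, GEMM 1
    cutlass::gemm::GemmShape<16, 8, 16>,    // tensor core instruction shape
    EpilogueOp0, EpilogueOp1,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,                                      // Stages
    false,                                  // SplitKSerial
    cutlass::arch::OpMultiplyAdd,
    true                                    // stage the GEMM0 accumulator in shared memory
>::B2bGemmKernel;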
95
examples/13_two_tensor_op_fusion/test_run.h
Normal file
@@ -0,0 +1,95 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
|
||||
#include <iostream>
// Needed for the std::vector and std::string parameters of testRun() below.
#include <string>
#include <vector>
|
||||
|
||||
// Run tests on GPUs
|
||||
|
||||
int testRun(int arch, std::vector<bool (*)()> & test_funcs, const std::string & test_name) {
|
||||
|
||||
bool supported = false;
|
||||
|
||||
int arch_major = arch / 10;
|
||||
int arch_minor = arch - arch / 10 * 10;
|
||||
|
||||
if(arch_major >= 8) {
|
||||
// Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples.
|
||||
if (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) {
|
||||
supported = true;
|
||||
}
|
||||
}
|
||||
else if(arch_major >= 7) {
|
||||
// Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
|
||||
if (__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) {
|
||||
supported = true;
|
||||
}
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!(props.major == arch_major && props.minor == arch_minor)) {
|
||||
supported = false;
|
||||
}
|
||||
|
||||
if (!supported) {
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
std::cout << "This example isn't supported on the current architecture" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
|
||||
std::cout << "Device: " << props.name << std::endl;
|
||||
std::cout << "Arch: SM" << arch << std::endl;
|
||||
std::cout << "Test: " << test_name << std::endl;
|
||||
for(auto func : test_funcs) {
|
||||
pass &= func();
|
||||
}
|
||||
|
||||
|
||||
if(pass)
|
||||
return 0;
|
||||
else
|
||||
return -1;
|
||||
|
||||
}
|
||||
|
||||
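// Illustrative usage sketch (not part of this change): how an example's main()
// might drive the testRun() helper above. run_fused_gemm_f16() and the include
// path are hypothetical placeholders.
#include <string>
#include <vector>

#include "test_run.h"   // assumed include path for the helper above

bool run_fused_gemm_f16() {
  // Build, launch, and verify a fused kernel here; return true on pass.
  return true;
}

int main() {
  std::vector<bool (*)()> funcs = { &run_fused_gemm_f16 };
  // '80' requests SM80; testRun() returns 0 when the tests pass or when the
  // current device or toolkit cannot run them (treated as a skip).
  return testRun(80, funcs, "fused GEMM example");
}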
@@ -0,0 +1,832 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/arch/cache_operation.h"
|
||||
#include "cutlass/gemm/threadblock/mma_base.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "threadblock/b2b_mma_base.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape0_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorA0_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA0_,
|
||||
/// Cache operation for operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA0,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB0_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB0_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB0,
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape1_,
|
||||
/// Iterates over the intermediate accumulator tile
|
||||
// (concept::MmaTensorOpFragmentIterator)
|
||||
typename FragmentIteratorA1_,
|
||||
/// Iterates over vectors of scale and bias vector in global memory
|
||||
// (concept: VectorIterator)
|
||||
typename IteratorAccumulatorScaleBias_,
|
||||
/// WarpIterator to load Scale or Bias vector from threadblock fragment
|
||||
typename FragmentIteratorA1ScaleBias_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB1_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB1_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB1,
|
||||
/// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...)
|
||||
typename OutputOp_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy0_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy1_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class B2bImplicitGemmMultistage :
|
||||
public gemm::threadblock::B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, Stages> {
|
||||
public:
|
||||
///< Base class
|
||||
using Base = gemm::threadblock::B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, Stages>;
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape0 = Shape0_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA0 = IteratorA0_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB0 = IteratorB0_;
|
||||
///< Policy describing tuning details
|
||||
using Policy0 = Policy0_;
|
||||
|
||||
using SmemIteratorA0 = SmemIteratorA0_;
|
||||
using SmemIteratorB0 = SmemIteratorB0_;
|
||||
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape1 = Shape1_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using FragmentIteratorA1 = FragmentIteratorA1_;
|
||||
///< Iterates over tiles of the scale and bias vectors in global memory
|
||||
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_;
|
||||
///< WarpIterator to load Scale or Bias vector from threadblock fragment
|
||||
using FragmentIteratorA1ScaleBias = FragmentIteratorA1ScaleBias_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB1 = IteratorB1_;
|
||||
///< Policy describing tuning details
|
||||
using Policy1 = Policy1_;
|
||||
|
||||
using SmemIteratorB1 = SmemIteratorB1_;
|
||||
|
||||
///< Epilogue after 1st Gemm
|
||||
using OutputOp = OutputOp_;
|
||||
|
||||
static const bool PerChannelScale = (OutputOp::kScale ==
|
||||
epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling);
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
using ElementC = typename Policy0::Operator::ElementC;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC0 = typename Policy0::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator0 = typename Policy0::Operator;
|
||||
|
||||
/// Fragment of Scale and Bias loaded from global memory
|
||||
using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC1 = typename Policy1::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator1 = typename Policy1::Operator;
|
||||
|
||||
/// Internal structure exposed for introspection.
|
||||
struct Detail {
|
||||
|
||||
static_assert(Base::kWarpGemmIterations0 > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
static_assert(Base::kWarpGemmIterations1 > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand A
|
||||
static int const AsyncCopyIterationsPerStageA0 =
|
||||
IteratorA0::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand B
|
||||
static int const AsyncCopyIterationsPerStageB0 =
|
||||
IteratorB0::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand B
|
||||
static int const AsyncCopyIterationsPerStageB1 =
|
||||
IteratorB1::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand A
|
||||
static int const kAccessesPerGroupA0 =
|
||||
(AsyncCopyIterationsPerStageA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand B
|
||||
static int const kAccessesPerGroupB0 =
|
||||
(AsyncCopyIterationsPerStageB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand B
|
||||
static int const kAccessesPerGroupB1 =
|
||||
(AsyncCopyIterationsPerStageB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1;
|
||||
};
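// Worked example for the group sizing above (hypothetical numbers, not taken
// from this change): with AsyncCopyIterationsPerStageA0 = 8 cp.async accesses
// per stage and kWarpGemmIterations0 = 3 warp-level MMAs per stage,
// kAccessesPerGroupA0 = (8 + 3 - 1) / 3 = 3, so the copies are interleaved
// with the math in groups of at most three per warp MMA iteration.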
|
||||
|
||||
private:
|
||||
|
||||
using WarpLoadedFragmentA0 = typename Operator0::FragmentA;
|
||||
using WarpLoadedFragmentB0 = typename Operator0::FragmentB;
|
||||
/// Warp Fragment of operand A1 loaded from accmulator tile
|
||||
using WarpLoadedFragmentA1 = typename FragmentIteratorA1::Fragment;
|
||||
using WarpLoadedFragmentA1ScaleBias =
|
||||
typename FragmentIteratorA1ScaleBias::Fragment;
|
||||
using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
|
||||
using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB;
|
||||
using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB;
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA0 smem_iterator_A0_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB0 smem_iterator_B0_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB1 smem_iterator_B1_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
B2bImplicitGemmMultistage(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::B2bMmaSharedStorage &shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A0_(shared_storage.shared_storage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx)
|
||||
{
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A0_.add_tile_offset(
|
||||
{warp_idx_m, Base::kWarpGemmIterations0 * warp_idx_k});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations0 * warp_idx_k, warp_idx_n});
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations1 * warp_idx_k, warp_idx_n});
|
||||
}
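// Worked example for the mapping above (hypothetical shape, not taken from
// this change): with WarpCount0 = {kM = 2, kN = 2, kK = 1} there are four
// warps per threadblock; warp_idx = 3 gives warp_idx_mn = 3 % 4 = 3,
// warp_idx_k = 3 / 4 = 0, warp_idx_m = 3 % 2 = 1, and warp_idx_n = 3 / 2 = 1,
// i.e. this warp owns the lower-right warp tile of the only K partition.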
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_0(
|
||||
IteratorA0 &iterator_A0, IteratorB0 &iterator_B0,
|
||||
int group_start_A0 = 0, int group_start_B0 = 0) {
|
||||
|
||||
iterator_A0.set_iteration_index(group_start_A0);
|
||||
this->smem_iterator_A0_.set_iteration_index(group_start_A0);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) {
|
||||
|
||||
if (group_start_A0 + j < Detail::AsyncCopyIterationsPerStageA0) {
|
||||
typename IteratorA0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA0::AccessType *>(
|
||||
this->smem_iterator_A0_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA0::Element>::value *
|
||||
IteratorA0::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA0>(
|
||||
dst_ptr, iterator_A0.get(), iterator_A0.valid());
|
||||
|
||||
++iterator_A0;
|
||||
|
||||
++this->smem_iterator_A0_;
|
||||
}
|
||||
}
|
||||
|
||||
iterator_B0.set_iteration_index(group_start_B0);
|
||||
|
||||
this->smem_iterator_B0_.set_iteration_index(group_start_B0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) {
|
||||
if (group_start_B0 + j < Detail::AsyncCopyIterationsPerStageB0) {
|
||||
typename IteratorB0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB0::AccessType *>(
|
||||
this->smem_iterator_B0_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB0::Element>::value *
|
||||
IteratorB0::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB0>(
|
||||
dst_ptr, iterator_B0.get(), iterator_B0.valid());
|
||||
|
||||
++iterator_B0;
|
||||
++this->smem_iterator_B0_;
|
||||
}
|
||||
}
|
||||
}
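// Worked example for the copy size above (hypothetical types, not taken from
// this change): for a half-precision operand, sizeof_bits = 16 and, with
// ThreadMap::kElementsPerAccess = 8, kSrcBytes = 16 * 8 / 8 = 16, so each
// cp_async_zfill issues one 16-byte (128-bit) global->shared copy.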
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_1(
|
||||
IteratorB1 &iterator_B1,
|
||||
int group_start_B1 = 0) {
|
||||
|
||||
iterator_B1.set_iteration_index(group_start_B1);
|
||||
|
||||
this->smem_iterator_B1_.set_iteration_index(group_start_B1);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) {
|
||||
if (group_start_B1 + j < Detail::AsyncCopyIterationsPerStageB1) {
|
||||
typename IteratorB1::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB1::AccessType *>(
|
||||
this->smem_iterator_B1_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB1::Element>::value *
|
||||
IteratorB1::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB1>(
|
||||
dst_ptr, iterator_B1.get(), iterator_B1.valid());
|
||||
|
||||
++iterator_B1;
|
||||
++this->smem_iterator_B1_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
///< problem size of GEMM
|
||||
int gemm_k_iterations_0,
|
||||
///< destination accumulator tile
|
||||
FragmentC1 &accum,
|
||||
///< iterator over A0 operand in global memory
|
||||
IteratorA0 iterator_A0,
|
||||
///< iterator over B0 operand in global memory
|
||||
IteratorB0 iterator_B0,
|
||||
///< iterator over A1 operand scale vector in global memory
|
||||
IteratorAccumulatorScaleBias iterator_A1_scale,
|
||||
///< iterator over A1 operand bias vector in global memory
|
||||
IteratorAccumulatorScaleBias iterator_A1_bias,
|
||||
///< iterator over B1 operand in global memory
|
||||
IteratorB1 iterator_B1,
|
||||
///< initial value of accumulator
|
||||
FragmentC0 const &src_accum,
|
||||
///< epilogue operation after 1st Gemm
|
||||
OutputOp output_op_0,
|
||||
///< Imaginary strides used for planar-complex only - ignored here
|
||||
int64_t imag_stride_A = 0,
|
||||
int64_t imag_stride_B = 0) {
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations_0) {
|
||||
|
||||
iterator_A0.set_iteration_index(0);
|
||||
this->smem_iterator_A0_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA0; ++j) {
|
||||
typename IteratorA0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA0::AccessType *>(
|
||||
this->smem_iterator_A0_.get());
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorA0::Element>::value *
|
||||
IteratorA0::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA0>(
|
||||
dst_ptr, iterator_A0.get(), iterator_A0.valid());
|
||||
|
||||
++iterator_A0;
|
||||
++this->smem_iterator_A0_;
|
||||
}
|
||||
|
||||
iterator_B0.set_iteration_index(0);
|
||||
this->smem_iterator_B0_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB0; ++j) {
|
||||
typename IteratorB0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB0::AccessType *>(
|
||||
this->smem_iterator_B0_.get());
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB0::Element>::value *
|
||||
IteratorB0::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB0>(
|
||||
dst_ptr, iterator_B0.get(), iterator_B0.valid());
|
||||
|
||||
++iterator_B0;
|
||||
++this->smem_iterator_B0_;
|
||||
}
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A0.advance();
|
||||
iterator_B0.advance();
|
||||
|
||||
this->smem_iterator_A0_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B0_.add_tile_offset({1, 0});
|
||||
|
||||
// Inserts a fence to group cp.async instructions into stages.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
FragmentC0 accum0 = src_accum;
|
||||
|
||||
// Waits until kStages-2 stages have committed.
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
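// cp_async_wait<kStages - 2> blocks until at most kStages-2 committed cp.async stages
// remain in flight, i.e. the oldest prologue stage has landed in shared memory; the
// barrier then makes that data visible to all warps before the first warp-level loads.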
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpLoadedFragmentA0 warp_loaded_frag_A0[2];
|
||||
WarpLoadedFragmentB0 warp_loaded_frag_B0[2];
|
||||
WarpTransformedFragmentA0 warp_transformed_frag_A0[2];
|
||||
WarpTransformedFragmentB0 warp_transformed_frag_B0[2];
|
||||
|
||||
Operator0 warp_mma0;
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]);
|
||||
this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
// Start issuing the first group of the next stage outside of the mainloop
|
||||
copy_tiles_and_advance_0(iterator_A0, iterator_B0);
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
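// The shared-memory tile buffer is circular over kStages stages: the prologue has
// already filled stages [0, kStages-2], so writes resume at index kStages-1 while
// reads start from stage 0.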
|
||||
|
||||
warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0],
|
||||
warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations_0 > (-Base::kStages + 1);) {
|
||||
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0;
|
||||
++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if
|
||||
// this is the last group as the case may be.
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
if (warp_mma_k > 0)
|
||||
warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2],
|
||||
warp_transformed_frag_B0[warp_mma_k % 2],
|
||||
warp_loaded_frag_A0[warp_mma_k % 2],
|
||||
warp_loaded_frag_B0[warp_mma_k % 2]);
|
||||
|
||||
// Issue global->shared copies for the next stage
|
||||
int group_start_iteration_A0, group_start_iteration_B0;
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations0) {
|
||||
group_start_iteration_A0 = 0;
|
||||
group_start_iteration_B0 = 0;
|
||||
} else {
|
||||
group_start_iteration_A0 =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupA0;
|
||||
group_start_iteration_B0 =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupB0;
|
||||
}
|
||||
|
||||
copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0,
|
||||
group_start_iteration_B0);
|
||||
|
||||
warp_mma0(
|
||||
accum0,
|
||||
warp_transformed_frag_A0[warp_mma_k % 2],
|
||||
warp_transformed_frag_B0[warp_mma_k % 2],
|
||||
accum0
|
||||
);
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations0)
|
||||
warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B0[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A0[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations0) {
|
||||
// Inserts a fence to group cp.async instructions into stages.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Waits until kStages-2 stages of cp.async have committed
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A0.advance();
|
||||
iterator_B0.advance();
|
||||
|
||||
this->smem_iterator_A0_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B0_.add_tile_offset({1, 0});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1)) {
|
||||
this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx = 0;
|
||||
} else {
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1)) {
|
||||
this->warp_tile_iterator_A0_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy0::kPartitionsK *
|
||||
Base::kWarpGemmIterations0});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset(
|
||||
{-Base::kStages * Policy0::kPartitionsK *
|
||||
Base::kWarpGemmIterations0,
|
||||
0});
|
||||
smem_read_stage_idx = 0;
|
||||
} else {
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
--gemm_k_iterations_0;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Insert fence and wait for all outstanding cp.async operations to commit.
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
|
||||
// 2nd Implicit Gemm
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
|
||||
FragmentIteratorA1 warp_tile_iterator_A1_(accum0);
|
||||
FragmentA1ScaleBias tb_frag_A1_scale;
|
||||
FragmentA1ScaleBias tb_frag_A1_bias;
|
||||
FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale);
|
||||
FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias);
|
||||
|
||||
if(PerChannelScale) {
|
||||
tb_frag_A1_scale.clear();
|
||||
iterator_A1_scale.load(tb_frag_A1_scale);
|
||||
++iterator_A1_scale;
|
||||
}
|
||||
tb_frag_A1_bias.clear();
|
||||
iterator_A1_bias.load(tb_frag_A1_bias);
|
||||
++iterator_A1_bias;
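// The (optional) per-channel scale and the bias vectors for the second GEMM's A operand
// are staged in register fragments here and refreshed once per K group inside the
// mainloop below; when PerChannelScale is disabled only the bias vector is streamed.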
|
||||
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1;
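// The second GEMM reads its A operand straight from the first GEMM's accumulators via
// FragmentIteratorA1, so the number of threadblock-level K iterations is the fragment
// iterator's total iteration count divided by the warp-level MMAs per K group.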
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations_1) {
|
||||
|
||||
iterator_B1.set_iteration_index(0);
|
||||
this->smem_iterator_B1_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB1; ++j) {
|
||||
typename IteratorB1::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB1::AccessType *>(
|
||||
this->smem_iterator_B1_.get());
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB1::Element>::value *
|
||||
IteratorB1::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB1>(
|
||||
dst_ptr, iterator_B1.get(), iterator_B1.valid());
|
||||
|
||||
++iterator_B1;
|
||||
++this->smem_iterator_B1_;
|
||||
}
|
||||
|
||||
// Move to the next stage
|
||||
iterator_B1.advance();
|
||||
|
||||
this->smem_iterator_B1_.add_tile_offset({1, 0});
|
||||
|
||||
// Inserts a fence to group cp.async instructions into stages.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// Waits until kStages-2 stages have committed.
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpLoadedFragmentA1 warp_loaded_frag_A1[2];
|
||||
WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_scale[2];
|
||||
WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_bias[2];
|
||||
WarpLoadedFragmentB1 warp_loaded_frag_B1[2];
|
||||
WarpTransformedFragmentA1 warp_transformed_frag_A1[2];
|
||||
WarpTransformedFragmentB1 warp_transformed_frag_B1[2];
|
||||
|
||||
Operator1 warp_mma1;
|
||||
|
||||
if(PerChannelScale) {
|
||||
warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]);
|
||||
++warp_tile_iterator_A1_scale_;
|
||||
}
|
||||
warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[0]);
|
||||
++warp_tile_iterator_A1_bias_;
|
||||
|
||||
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0],
|
||||
warp_loaded_frag_A1_scale[0],
|
||||
warp_loaded_frag_A1_bias[0],
|
||||
output_op_0);
|
||||
++warp_tile_iterator_A1_;
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]);
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
// Start issuing the first group of the next stage outside of the mainloop
|
||||
copy_tiles_and_advance_1(iterator_B1);
|
||||
|
||||
smem_write_stage_idx = Base::kStages - 1;
|
||||
smem_read_stage_idx = 0;
|
||||
|
||||
warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0],
|
||||
warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]);
|
||||
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1 - (Base::kStages - 1);
|
||||
gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1;
|
||||
++warp_mma_k) {
|
||||
|
||||
// Load threadblock-level scale/bias vector from global memory
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations1) {
|
||||
if(PerChannelScale) {
|
||||
tb_frag_A1_scale.clear();
|
||||
iterator_A1_scale.load(tb_frag_A1_scale);
|
||||
++iterator_A1_scale;
|
||||
}
|
||||
tb_frag_A1_bias.clear();
|
||||
iterator_A1_bias.load(tb_frag_A1_bias);
|
||||
++iterator_A1_bias;
|
||||
}
|
||||
|
||||
// Load warp-level scale bias fragment from threadblock scale/bias vector
|
||||
if(PerChannelScale) {
|
||||
warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]);
|
||||
++warp_tile_iterator_A1_scale_;
|
||||
}
|
||||
warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2]);
|
||||
++warp_tile_iterator_A1_bias_;
|
||||
|
||||
// Load warp-level tile from accumulator fragment
|
||||
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2],
|
||||
output_op_0);
|
||||
++warp_tile_iterator_A1_;
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if
|
||||
// this is the last group as the case may be.
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
|
||||
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
|
||||
if (warp_mma_k > 0)
|
||||
warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
warp_loaded_frag_A1[warp_mma_k % 2],
|
||||
warp_loaded_frag_B1[warp_mma_k % 2]);
|
||||
|
||||
// Issue global->shared copies for the next stage
|
||||
int group_start_iteration_B1;
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations1) {
|
||||
group_start_iteration_B1 = 0;
|
||||
} else {
|
||||
group_start_iteration_B1 =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupB1;
|
||||
}
|
||||
|
||||
copy_tiles_and_advance_1(iterator_B1,
|
||||
group_start_iteration_B1);
|
||||
|
||||
warp_mma1(
|
||||
accum,
|
||||
warp_transformed_frag_A1[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
accum
|
||||
);
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations1)
|
||||
warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A1[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations1) {
|
||||
// Inserts a fence to group cp.async instructions into stages.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Waits until kStages-2 stages of cp.async have committed
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_B1.advance();
|
||||
|
||||
this->smem_iterator_B1_.add_tile_offset({1, 0});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1)) {
|
||||
this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx = 0;
|
||||
} else {
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1)) {
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{-Base::kStages * Policy1::kPartitionsK *
|
||||
Base::kWarpGemmIterations1,
|
||||
0});
|
||||
smem_read_stage_idx = 0;
|
||||
} else {
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Insert fence and wait for all outstanding cp.async operations to commit.
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,816 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel.
|
||||
*/
|
||||
|
||||
#pragma once

#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/cache_operation.h"
#include "cutlass/gemm/threadblock/mma_base.h"
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"

#include "threadblock/b2b_mma_base_smem_accumulator.h"
#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape0_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorA0_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA0_,
|
||||
/// Cache operation for operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA0,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB0_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB0_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB0,
|
||||
/// Iterates over vectors of scale and bias vector in global memory
|
||||
// (concept: VectorIterator)
|
||||
typename IteratorAccumulatorScaleBias_,
|
||||
/// Iterates over accumulator tile
|
||||
typename FragmentIteratorAccumulator_,
|
||||
/// Iterates over accumulator tile in shared memory
|
||||
typename SmemIteratorD0_,
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape1_,
|
||||
/// Iterates over the intermediate accumulator tile
|
||||
// (concept::MmaTensorOpFragmentIterator)
|
||||
typename WarpIteratorA1_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB1_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB1_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB1,
|
||||
/// Output operator for 1st Gemm (concept: epilogue::thread::LinearCombinationClamp, etc...)
|
||||
typename OutputOp_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy0_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy1_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class B2bImplicitGemmMultistageSmemAccumulator :
|
||||
public gemm::threadblock::B2bMmaBaseSmemAccumulator<Shape0_, Shape1_, Policy0_, Policy1_, SmemIteratorD0_, Stages> {
|
||||
public:
|
||||
///< Base class
|
||||
using Base = gemm::threadblock::B2bMmaBaseSmemAccumulator<Shape0_, Shape1_, Policy0_, Policy1_, SmemIteratorD0_, Stages>;
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape0 = Shape0_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA0 = IteratorA0_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB0 = IteratorB0_;
|
||||
///< Iterates over tiles of the scale and bias vectors in global memory
|
||||
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_;
|
||||
///< Policy describing tuning details
|
||||
using Policy0 = Policy0_;
|
||||
|
||||
using SmemIteratorA0 = SmemIteratorA0_;
|
||||
using SmemIteratorB0 = SmemIteratorB0_;
|
||||
using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory
|
||||
|
||||
using FragmentIteratorAccumulator = FragmentIteratorAccumulator_; ///< Iterates over accumulator tile
|
||||
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape1 = Shape1_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB1 = IteratorB1_;
|
||||
///< Policy describing tuning details
|
||||
using Policy1 = Policy1_;
|
||||
|
||||
using SmemIteratorB1 = SmemIteratorB1_;
|
||||
using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory
|
||||
|
||||
///< Epilogue after 1st Gemm
|
||||
using OutputOp = OutputOp_;
|
||||
|
||||
static const bool PerChannelScale = (OutputOp::kScale ==
|
||||
epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling);
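// Per-channel scaling is enabled only when the first GEMM's output op uses
// OnlyAlphaPerChannelScaling; the flag controls whether a per-channel scale vector
// is fetched in addition to the bias vector.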
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
using ElementC = typename Policy0::Operator::ElementC;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC0 = typename Policy0::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator0 = typename Policy0::Operator;
|
||||
|
||||
/// Fragment of Scale and Bias loaded from global memory
|
||||
using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC1 = typename Policy1::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator1 = typename Policy1::Operator;
|
||||
|
||||
/// Epilogue in shared memory
|
||||
using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator<
|
||||
SmemIteratorD0, ///< SmemTileIterator
|
||||
FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator
|
||||
IteratorAccumulatorScaleBias, ///< ScaleBiasIterator
|
||||
OutputOp>; ///< Output operator
|
||||
|
||||
  /// Internal structure exposed for introspection.
  struct Detail {

    static_assert(Base::kWarpGemmIterations0 > 1,
                  "The pipelined structure requires at least two warp-level "
                  "GEMM operations.");
    static_assert(Base::kWarpGemmIterations1 > 1,
                  "The pipelined structure requires at least two warp-level "
                  "GEMM operations.");

    /// Number of cp.async instructions to load one stage of operand A
    static int const AsyncCopyIterationsPerStageA0 =
        IteratorA0::ThreadMap::Iterations::kCount;

    /// Number of cp.async instructions to load one stage of operand B0
    static int const AsyncCopyIterationsPerStageB0 =
        IteratorB0::ThreadMap::Iterations::kCount;

    /// Number of cp.async instructions to load one stage of operand B1
    static int const AsyncCopyIterationsPerStageB1 =
        IteratorB1::ThreadMap::Iterations::kCount;

    /// Number of stages
    static int const kStages = Stages;

    /// Number of cp.async instructions to load one group of operand A
    static int const kAccessesPerGroupA0 =
        (AsyncCopyIterationsPerStageA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0;

    /// Number of cp.async instructions to load one group of operand B0
    static int const kAccessesPerGroupB0 =
        (AsyncCopyIterationsPerStageB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0;

    /// Number of cp.async instructions to load one group of operand B1
    static int const kAccessesPerGroupB1 =
        (AsyncCopyIterationsPerStageB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1;
  };
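// The kAccessesPerGroup* values above use rounded-up (ceiling) division so that one
// stage's cp.async copies are spread across the warp-level MMA iterations;
// e.g. 9 per-stage copies split over 4 warp MMAs gives groups of ceil(9/4) = 3 accesses
// (the numbers are illustrative only).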
|
||||
|
||||
private:
|
||||
|
||||
using WarpLoadedFragmentA0 = typename Operator0::FragmentA;
|
||||
using WarpLoadedFragmentB0 = typename Operator0::FragmentB;
|
||||
using WarpLoadedFragmentA1 = typename Operator1::FragmentA;
|
||||
using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
|
||||
using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB;
|
||||
using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB;
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA0 smem_iterator_A0_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB0 smem_iterator_B0_;
|
||||
|
||||
/// Shared Memory Iterator to store accumulator tile
|
||||
SmemIteratorD0 smem_iterator_D0_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
|
||||
WarpIteratorA1 warp_tile_iterator_A1_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB1 smem_iterator_B1_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
B2bImplicitGemmMultistageSmemAccumulator(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::B2bMmaSharedStorage &shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
|
||||
warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
|
||||
smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx)
|
||||
{
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
|
||||
int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM;
|
||||
int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM;
|
||||
|
||||
int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN);
|
||||
int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN);
|
||||
|
||||
int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM;
|
||||
int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM;
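// For illustration: with a hypothetical 2x2x1 warp arrangement
// (WarpCount::kM = kN = 2, kK = 1), warp_idx 0..3 decompose to
// (m, n, k) = (0,0,0), (1,0,0), (0,1,0), (1,1,0).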
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A0_.add_tile_offset(
|
||||
{warp_idx_m_0, Base::kWarpGemmIterations0 * warp_idx_k_0});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations0 * warp_idx_k_0, warp_idx_n_0});
|
||||
warp_tile_iterator_A1_.add_tile_offset(
|
||||
{warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1});
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1});
|
||||
|
||||
// Add smem accumulator iterator warp offset
|
||||
smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow,
|
||||
warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn});
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_0(
|
||||
IteratorA0 &iterator_A0, IteratorB0 &iterator_B0,
|
||||
int group_start_A0 = 0, int group_start_B0 = 0) {
|
||||
|
||||
iterator_A0.set_iteration_index(group_start_A0);
|
||||
this->smem_iterator_A0_.set_iteration_index(group_start_A0);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) {
|
||||
|
||||
if (group_start_A0 + j < Detail::AsyncCopyIterationsPerStageA0) {
|
||||
typename IteratorA0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA0::AccessType *>(
|
||||
this->smem_iterator_A0_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA0::Element>::value *
|
||||
IteratorA0::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA0>(
|
||||
dst_ptr, iterator_A0.get(), iterator_A0.valid());
|
||||
|
||||
++iterator_A0;
|
||||
|
||||
++this->smem_iterator_A0_;
|
||||
}
|
||||
}
|
||||
|
||||
iterator_B0.set_iteration_index(group_start_B0);
|
||||
|
||||
this->smem_iterator_B0_.set_iteration_index(group_start_B0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) {
|
||||
if (group_start_B0 + j < Detail::AsyncCopyIterationsPerStageB0) {
|
||||
typename IteratorB0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB0::AccessType *>(
|
||||
this->smem_iterator_B0_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB0::Element>::value *
|
||||
IteratorB0::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB0>(
|
||||
dst_ptr, iterator_B0.get(), iterator_B0.valid());
|
||||
|
||||
++iterator_B0;
|
||||
++this->smem_iterator_B0_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_1(
|
||||
IteratorB1 &iterator_B1,
|
||||
int group_start_B1 = 0) {
|
||||
|
||||
iterator_B1.set_iteration_index(group_start_B1);
|
||||
|
||||
this->smem_iterator_B1_.set_iteration_index(group_start_B1);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) {
|
||||
if (group_start_B1 + j < Detail::AsyncCopyIterationsPerStageB1) {
|
||||
typename IteratorB1::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB1::AccessType *>(
|
||||
this->smem_iterator_B1_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB1::Element>::value *
|
||||
IteratorB1::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB1>(
|
||||
dst_ptr, iterator_B1.get(), iterator_B1.valid());
|
||||
|
||||
++iterator_B1;
|
||||
++this->smem_iterator_B1_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
///< problem size of GEMM
|
||||
int gemm_k_iterations_0,
|
||||
///< destination accumulator tile
|
||||
FragmentC1 &accum,
|
||||
///< iterator over A0 operand in global memory
|
||||
IteratorA0 iterator_A0,
|
||||
///< iterator over B0 operand in global memory
|
||||
IteratorB0 iterator_B0,
|
||||
///< iterator over A1 operand scale vector in global memory
|
||||
IteratorAccumulatorScaleBias iterator_accum0_scale,
|
||||
///< iterator over A1 operand bias vector in global memory
|
||||
IteratorAccumulatorScaleBias iterator_accum0_bias,
|
||||
///< iterator over B1 operand in global memory
|
||||
IteratorB1 iterator_B1,
|
||||
///< initial value of accumulator
|
||||
FragmentC0 const &src_accum,
|
||||
///< epilogue operation after 1st Gemm
|
||||
OutputOp output_op_0,
|
||||
///< Imaginary strides used for planar-complex only - ignored here
|
||||
int64_t imag_stride_A = 0,
|
||||
int64_t imag_stride_B = 0) {
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations_0) {
|
||||
|
||||
iterator_A0.set_iteration_index(0);
|
||||
this->smem_iterator_A0_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA0; ++j) {
|
||||
typename IteratorA0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA0::AccessType *>(
|
||||
this->smem_iterator_A0_.get());
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorA0::Element>::value *
|
||||
IteratorA0::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA0>(
|
||||
dst_ptr, iterator_A0.get(), iterator_A0.valid());
|
||||
|
||||
++iterator_A0;
|
||||
++this->smem_iterator_A0_;
|
||||
}
|
||||
|
||||
iterator_B0.set_iteration_index(0);
|
||||
this->smem_iterator_B0_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB0; ++j) {
|
||||
typename IteratorB0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB0::AccessType *>(
|
||||
this->smem_iterator_B0_.get());
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB0::Element>::value *
|
||||
IteratorB0::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB0>(
|
||||
dst_ptr, iterator_B0.get(), iterator_B0.valid());
|
||||
|
||||
++iterator_B0;
|
||||
++this->smem_iterator_B0_;
|
||||
}
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A0.advance();
|
||||
iterator_B0.advance();
|
||||
|
||||
this->smem_iterator_A0_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B0_.add_tile_offset({1, 0});
|
||||
|
||||
// Inserts a fence to group cp.async instructions into stages.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
FragmentC0 accum0 = src_accum;
|
||||
|
||||
// Waits until kStages-2 stages have committed.
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpLoadedFragmentA0 warp_loaded_frag_A0[2];
|
||||
WarpLoadedFragmentB0 warp_loaded_frag_B0[2];
|
||||
WarpTransformedFragmentA0 warp_transformed_frag_A0[2];
|
||||
WarpTransformedFragmentB0 warp_transformed_frag_B0[2];
|
||||
|
||||
Operator0 warp_mma0;
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]);
|
||||
this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
// Start issuing the first group of the next stage outside of the mainloop
|
||||
copy_tiles_and_advance_0(iterator_A0, iterator_B0);
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
|
||||
warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0],
|
||||
warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations_0 > (-Base::kStages + 1);) {
|
||||
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0;
|
||||
++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if
|
||||
// this is the last group as the case may be.
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
if (warp_mma_k > 0)
|
||||
warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2],
|
||||
warp_transformed_frag_B0[warp_mma_k % 2],
|
||||
warp_loaded_frag_A0[warp_mma_k % 2],
|
||||
warp_loaded_frag_B0[warp_mma_k % 2]);
|
||||
|
||||
// Issue global->shared copies for the next stage
|
||||
int group_start_iteration_A0, group_start_iteration_B0;
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations0) {
|
||||
group_start_iteration_A0 = 0;
|
||||
group_start_iteration_B0 = 0;
|
||||
} else {
|
||||
group_start_iteration_A0 =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupA0;
|
||||
group_start_iteration_B0 =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupB0;
|
||||
}
|
||||
|
||||
copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0,
|
||||
group_start_iteration_B0);
|
||||
|
||||
warp_mma0(
|
||||
accum0,
|
||||
warp_transformed_frag_A0[warp_mma_k % 2],
|
||||
warp_transformed_frag_B0[warp_mma_k % 2],
|
||||
accum0
|
||||
);
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations0)
|
||||
warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B0[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A0[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations0) {
|
||||
// Inserts a fence to group cp.async instructions into stages.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Waits until kStages-2 stages of cp.async have committed
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A0.advance();
|
||||
iterator_B0.advance();
|
||||
|
||||
this->smem_iterator_A0_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B0_.add_tile_offset({1, 0});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1)) {
|
||||
this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx = 0;
|
||||
} else {
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1)) {
|
||||
this->warp_tile_iterator_A0_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy0::kPartitionsK *
|
||||
Base::kWarpGemmIterations0});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset(
|
||||
{-Base::kStages * Policy0::kPartitionsK *
|
||||
Base::kWarpGemmIterations0,
|
||||
0});
|
||||
smem_read_stage_idx = 0;
|
||||
} else {
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
--gemm_k_iterations_0;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Insert fence and wait for all outstanding cp.async operations to commit.
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
/// Epilogue for the first Implicit Gemm
|
||||
Epilogue0 epilogue0;
|
||||
|
||||
epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias);
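// The shared-memory epilogue applies output_op_0 (plus the optional per-channel scale
// and the bias vectors) to accum0 and stores the result through smem_iterator_D0_ into
// the accumulator shared storage, where warp_tile_iterator_A1_ reads it back as the A
// operand of the second implicit GEMM.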
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// 2nd Implicit Gemm
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
int gemm_k_iterations_1 = Shape0::kN / Shape1::kK;
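// In the fused pair, the N extent of the first GEMM's threadblock tile is the K extent
// of the second GEMM, so the second mainloop runs Shape0::kN / Shape1::kK
// threadblock-level K iterations.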
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations_1) {
|
||||
|
||||
iterator_B1.set_iteration_index(0);
|
||||
this->smem_iterator_B1_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB1; ++j) {
|
||||
typename IteratorB1::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB1::AccessType *>(
|
||||
this->smem_iterator_B1_.get());
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB1::Element>::value *
|
||||
IteratorB1::ThreadMap::kElementsPerAccess / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB1>(
|
||||
dst_ptr, iterator_B1.get(), iterator_B1.valid());
|
||||
|
||||
++iterator_B1;
|
||||
++this->smem_iterator_B1_;
|
||||
}
|
||||
|
||||
// Move to the next stage
|
||||
iterator_B1.advance();
|
||||
|
||||
this->smem_iterator_B1_.add_tile_offset({1, 0});
|
||||
|
||||
// Inserts a fence to group cp.async instructions into stages.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// Waits until kStages-2 stages have committed.
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpLoadedFragmentA1 warp_loaded_frag_A1[2];
|
||||
WarpLoadedFragmentB1 warp_loaded_frag_B1[2];
|
||||
WarpTransformedFragmentA1 warp_transformed_frag_A1[2];
|
||||
WarpTransformedFragmentB1 warp_transformed_frag_B1[2];
|
||||
|
||||
Operator1 warp_mma1;
|
||||
|
||||
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]);
|
||||
++warp_tile_iterator_A1_;
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]);
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
// Start issuing the first group of the next stage outside of the mainloop
|
||||
copy_tiles_and_advance_1(iterator_B1);
|
||||
|
||||
smem_write_stage_idx = Base::kStages - 1;
|
||||
smem_read_stage_idx = 0;
|
||||
|
||||
warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0],
|
||||
warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]);
|
||||
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for ( gemm_k_iterations_1 = Shape0::kN / Shape1::kK - (Base::kStages - 1);
|
||||
gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1;
|
||||
++warp_mma_k) {
|
||||
|
||||
// Load warp-level tile from accumulator fragment
|
||||
// skip warp tile loading for the last kgroup
|
||||
if(gemm_k_iterations_1 > (-Base::kStages + 2) || warp_mma_k < Base::kWarpGemmIterations1 - 1) {
|
||||
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2]);
|
||||
}
|
||||
++warp_tile_iterator_A1_;
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if
|
||||
// this is the last group as the case may be.
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
|
||||
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
|
||||
if (warp_mma_k > 0)
|
||||
warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
warp_loaded_frag_A1[warp_mma_k % 2],
|
||||
warp_loaded_frag_B1[warp_mma_k % 2]);
|
||||
|
||||
// Issue global->shared copies for the next stage
|
||||
int group_start_iteration_B1;
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations1) {
|
||||
group_start_iteration_B1 = 0;
|
||||
} else {
|
||||
group_start_iteration_B1 =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupB1;
|
||||
}
|
||||
|
||||
copy_tiles_and_advance_1(iterator_B1,
|
||||
group_start_iteration_B1);
|
||||
|
||||
warp_mma1(
|
||||
accum,
|
||||
warp_transformed_frag_A1[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
accum
|
||||
);
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations1)
|
||||
warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A1[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations1) {
|
||||
// Inserts a fence to group cp.async instructions into stages.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Waits until kStages-2 stages of cp.async have committed
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_B1.advance();
|
||||
|
||||
this->smem_iterator_B1_.add_tile_offset({1, 0});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1)) {
|
||||
this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx = 0;
|
||||
} else {
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1)) {
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{-Base::kStages * Policy1::kPartitionsK *
|
||||
Base::kWarpGemmIterations1,
|
||||
0});
|
||||
smem_read_stage_idx = 0;
|
||||
} else {
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Insert fence and wait for all outstanding cp.async operations to commit.
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
} // namespace conv
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,553 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "threadblock/b2b_mma_base.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape0_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorA0_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA0_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorB0_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB0_,
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape1_,
|
||||
/// Iterates over the intermediate accumulator tile
|
||||
// (concept::MmaTensorOpFragmentIterator)
|
||||
typename FragmentIteratorA1_,
|
||||
/// Iterates over vectors of scale and bias vector in global memory
|
||||
// (concept: VectorIterator)
|
||||
typename IteratorAccumulatorScaleBias_,
|
||||
/// FragmentIterator to load Scale or Bias vector from threadblock fragment
|
||||
typename FragmentIteratorA1ScaleBias_,
|
||||
// (concept: VectorFragmentIterator)
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorB1_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB1_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Data type of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Output operator for 1st Gemm (concept: epilogue::thread::LinearCombinationClamp, etc...)
|
||||
typename OutputOp_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy0_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy1_,
|
||||
/// Transformation applied to A operand
|
||||
typename TransformA0_ = NumericArrayConverter<
|
||||
typename SmemIteratorA0_::Element,
|
||||
typename IteratorA0_::Element,
|
||||
IteratorA0_::Fragment::kElements>,
|
||||
///
|
||||
/// Transformation applied to B operand
|
||||
typename TransformB0_ = NumericArrayConverter<
|
||||
typename SmemIteratorB0_::Element,
|
||||
typename IteratorB0_::Element,
|
||||
IteratorB0_::Fragment::kElements>,
|
||||
///
|
||||
/// Transformation applied to B operand
|
||||
typename TransformB1_ = NumericArrayConverter<
|
||||
typename SmemIteratorB1_::Element,
|
||||
typename IteratorB1_::Element,
|
||||
IteratorB1_::Fragment::kElements>,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool
|
||||
>
|
||||
class B2bImplicitGemmPipelined :
|
||||
public gemm::threadblock::B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, 2> {
|
||||
public:
|
||||
|
||||
///< Base class
|
||||
using Base = gemm::threadblock::B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, 2>;
|
||||
|
||||
using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory
|
||||
using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory
|
||||
using Policy0 = Policy0_; ///< Policy0 describing tuning details
|
||||
|
||||
using SmemIteratorA0 = SmemIteratorA0_;
|
||||
using SmemIteratorB0 = SmemIteratorB0_;
|
||||
|
||||
using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using FragmentIteratorA1 = FragmentIteratorA1_; ///< Iterates over tiles of A1 operand from accumulator tile
|
||||
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory
|
||||
using FragmentIteratorA1ScaleBias =
|
||||
FragmentIteratorA1ScaleBias_; ///< WarpIterator to load Scale or Bias vector from the threadblock fragment
|
||||
using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory
|
||||
using Policy1 = Policy1_; ///< Policy1 describing tuning details
|
||||
|
||||
using SmemIteratorB1 = SmemIteratorB1_;
|
||||
|
||||
|
||||
using ElementC = ElementC_; ///< Data type of accumulator matrix
|
||||
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
|
||||
|
||||
using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm
|
||||
|
||||
static const bool PerChannelScale = (OutputOp::kScale ==
|
||||
epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling);
|
||||
|
||||
using TransformA0 = TransformA0_;
|
||||
using TransformB0 = TransformB0_;
|
||||
using TransformB1 = TransformB1_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of operand A loaded from global memory
|
||||
using FragmentA0 = typename IteratorA0::Fragment;
|
||||
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB0 = typename IteratorB0::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC0 = typename Policy0::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator0 = typename Policy0::Operator;
|
||||
|
||||
/// Fragment of Scale and Bias loaded from global memory
|
||||
using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment;
|
||||
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB1 = typename IteratorB1::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC1 = typename Policy1::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator1 = typename Policy1::Operator;
|
||||
|
||||
/// Obtain the arch tag from the warp-level operator
|
||||
using ArchTag = typename Policy0::Operator::ArchTag;
|
||||
|
||||
/// Complex transform on A0 operand
|
||||
static ComplexTransform const kTransformA0 = Operator0::kTransformA;
|
||||
|
||||
/// Complex transform on B0 operand
|
||||
static ComplexTransform const kTransformB0 = Operator0::kTransformB;
|
||||
|
||||
/// Complex transform on B1 operand
|
||||
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
|
||||
|
||||
// statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
|
||||
static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");
|
||||
|
||||
private:
|
||||
|
||||
using WarpFragmentA0 = typename Operator0::FragmentA;
|
||||
using WarpFragmentB0 = typename Operator0::FragmentB;
|
||||
/// Warp Fragment of operand A1 loaded from accumulator tile
|
||||
using WarpFragmentA1 = typename FragmentIteratorA1::Fragment;
|
||||
/// Warp Fragment of operand A1 scale and bias loaded from threadblock fragment
|
||||
using WarpFragmentA1ScaleBias =
|
||||
typename FragmentIteratorA1ScaleBias::Fragment;
|
||||
using WarpFragmentB1 = typename Operator1::FragmentB;
|
||||
|
||||
protected:
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA0 smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B0 operand to shared memory
|
||||
SmemIteratorB0 smem_iterator_B0_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B1 operand to shared memory
|
||||
SmemIteratorB1 smem_iterator_B1_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
B2bImplicitGemmPipelined(
|
||||
typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
int thread_idx, ///< ID within the threadblock
|
||||
int warp_idx, ///< ID of warp
|
||||
int lane_idx ///< ID of each thread within a warp
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A_(shared_storage.shared_storage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx) {
|
||||
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM;
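// Worked example, assuming a hypothetical WarpCount0 of <2, 2, 1>:
// warp_idx 3 -> warp_idx_mn = 3 % 4 = 3, warp_idx_k = 3 / 4 = 0,
// warp_idx_m = 3 % 2 = 1, warp_idx_n = 3 / 2 = 1.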
|
||||
|
||||
// These may change across different GEMM layers
|
||||
int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k;
|
||||
int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m, tile_offset_k_0});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n});
|
||||
this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n});
|
||||
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
int gemm_k_iterations_0, ///< number of iterations of the mainloop
|
||||
FragmentC1 &accum, ///< destination accumulator tile
|
||||
IteratorA0 iterator_A, ///< iterator over A operand in global memory
|
||||
IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory
|
||||
IteratorAccumulatorScaleBias iterator_A1_scale, ///< iterator over A1 operand scale vectors in global memory
|
||||
IteratorAccumulatorScaleBias iterator_A1_bias, ///< iterator over A1 operand bias vectors in global memory
|
||||
IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory
|
||||
FragmentC0 const &src_accum, ///< source accumulator tile
|
||||
OutputOp output_op_0, ///< epilogue operation after 1st Gemm
|
||||
TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment
|
||||
TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment
|
||||
TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
FragmentC0 accum0 = src_accum;
|
||||
|
||||
FragmentA0 tb_frag_A;
|
||||
FragmentB0 tb_frag_B0;
|
||||
|
||||
tb_frag_A.clear();
|
||||
tb_frag_B0.clear();
|
||||
|
||||
// The last kblock is loaded in the prolog
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B0.load(tb_frag_B0);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B0;
|
||||
|
||||
this->smem_iterator_A_.store(transform_A0(tb_frag_A));
|
||||
this->smem_iterator_B0_.store(transform_B0(tb_frag_B0));
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B0_;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math instructions
|
||||
WarpFragmentA0 warp_frag_A0[2];
|
||||
WarpFragmentB0 warp_frag_B0[2];
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_frag_A0[0]);
|
||||
this->warp_tile_iterator_B0_.load(warp_frag_B0[0]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
Operator0 warp_mma0;
|
||||
|
||||
int smem_write_stage_idx = 1;
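// smem_write_stage_idx records which of the two shared-memory stages was
// written last; at the end of each k-block either the shared-memory write
// iterators or the warp-level read iterators are rewound with a negative
// tile offset, so writes and reads ping-pong between the two stages of the
// double buffer (Base::kStages == 2).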
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tightest latency requirement).
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
|
||||
// as the case may be.
|
||||
|
||||
if (warp_mma_k == Base::kWarpGemmIterations0 - 1) {
|
||||
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_A_.store(transform_A0(tb_frag_A));
|
||||
|
||||
this->smem_iterator_B0_.store(transform_B0(tb_frag_B0));
|
||||
|
||||
__syncthreads();
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B0_;
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
|
||||
if (smem_write_stage_idx == 1) {
|
||||
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
|
||||
}
|
||||
else {
|
||||
this->warp_tile_iterator_A0_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset(
|
||||
{-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0,
|
||||
0});
|
||||
}
|
||||
|
||||
smem_write_stage_idx ^= 1;
|
||||
}
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B0.load(tb_frag_B0);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B0;
|
||||
}
|
||||
|
||||
warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2],
|
||||
warp_frag_B0[warp_mma_k % 2], accum0);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 2nd Implicit Gemm
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
|
||||
FragmentIteratorA1 warp_tile_iterator_A1_(accum0);
|
||||
|
||||
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
FragmentA1ScaleBias tb_frag_A1_scale;
|
||||
FragmentA1ScaleBias tb_frag_A1_bias;
|
||||
FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale);
|
||||
FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias);
|
||||
FragmentB1 tb_frag_B1;
|
||||
|
||||
if(PerChannelScale)
|
||||
tb_frag_A1_scale.clear();
|
||||
tb_frag_A1_bias.clear();
|
||||
tb_frag_B1.clear();
|
||||
|
||||
// The last kblock is loaded in the prolog
|
||||
if(PerChannelScale)
|
||||
iterator_A1_scale.load(tb_frag_A1_scale);
|
||||
iterator_A1_bias.load(tb_frag_A1_bias);
|
||||
iterator_B1.load(tb_frag_B1);
|
||||
|
||||
|
||||
if(PerChannelScale)
|
||||
++iterator_A1_scale;
|
||||
++iterator_A1_bias;
|
||||
++iterator_B1;
|
||||
|
||||
this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));
|
||||
|
||||
++this->smem_iterator_B1_;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math instructions
|
||||
WarpFragmentA1ScaleBias warp_frag_A1_scale[2];
|
||||
WarpFragmentA1ScaleBias warp_frag_A1_bias[2];
|
||||
WarpFragmentA1 warp_frag_A1[2];
|
||||
WarpFragmentB1 warp_frag_B1[2];
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index(0);
|
||||
|
||||
if(PerChannelScale)
|
||||
warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[0]);
|
||||
warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[0]);
|
||||
warp_tile_iterator_A1_.load(warp_frag_A1[0], warp_frag_A1_scale[0],
|
||||
warp_frag_A1_bias[0], output_op_0);
|
||||
this->warp_tile_iterator_B1_.load(warp_frag_B1[0]);
|
||||
|
||||
++warp_tile_iterator_A1_;
|
||||
if(PerChannelScale)
|
||||
++warp_tile_iterator_A1_scale_;
|
||||
++warp_tile_iterator_A1_bias_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
Operator1 warp_mma1;
|
||||
|
||||
smem_write_stage_idx = 1;
|
||||
|
||||
int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1;
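// The second GEMM consumes the first GEMM's accumulators as its A1 operand,
// so its mainloop count comes from the fragment iterator rather than global
// memory; e.g. with hypothetical values Policy::kIterations = 8 and
// kWarpGemmIterations1 = 2 this evaluates to 4 iterations.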
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tightest latency requirement).
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
|
||||
// as the case may be.
|
||||
|
||||
if (warp_mma_k == Base::kWarpGemmIterations1 - 1) {
|
||||
|
||||
this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));
|
||||
|
||||
__syncthreads();
|
||||
|
||||
++this->smem_iterator_B1_;
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
|
||||
if (smem_write_stage_idx == 1) {
|
||||
this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
|
||||
}
|
||||
else {
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{-Base::kStages * Policy1::kPartitionsK * Base::kWarpGemmIterations1,
|
||||
0});
|
||||
}
|
||||
|
||||
smem_write_stage_idx ^= 1;
|
||||
|
||||
if(PerChannelScale) {
|
||||
tb_frag_A1_scale.clear();
|
||||
iterator_A1_scale.load(tb_frag_A1_scale);
|
||||
++iterator_A1_scale;
|
||||
}
|
||||
tb_frag_A1_bias.clear();
|
||||
iterator_A1_bias.load(tb_frag_A1_bias);
|
||||
++iterator_A1_bias;
|
||||
}
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
|
||||
|
||||
if(PerChannelScale)
|
||||
warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[(warp_mma_k + 1) % 2]);
|
||||
warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[(warp_mma_k + 1) % 2]);
|
||||
warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2],
|
||||
warp_frag_A1_scale[(warp_mma_k + 1) % 2],
|
||||
warp_frag_A1_bias[(warp_mma_k + 1) % 2],
|
||||
output_op_0);
|
||||
this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
|
||||
if(PerChannelScale)
|
||||
++warp_tile_iterator_A1_scale_;
|
||||
++warp_tile_iterator_A1_bias_;
|
||||
++warp_tile_iterator_A1_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
|
||||
iterator_B1.load(tb_frag_B1);
|
||||
|
||||
++iterator_B1;
|
||||
}
|
||||
|
||||
warp_mma1(accum, warp_frag_A1[warp_mma_k % 2],
|
||||
warp_frag_B1[warp_mma_k % 2], accum);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,535 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "threadblock/b2b_mma_base_smem_accumulator.h"
|
||||
#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape0_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorA0_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA0_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorB0_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB0_,
|
||||
/// Iterates over vectors of scale and bias vector in global memory
|
||||
// (concept: VectorIterator)
|
||||
typename IteratorAccumulatorScaleBias_,
|
||||
/// Iterates over accumulator tile
|
||||
typename FragmentIteratorAccumulator_,
|
||||
/// Iterates over accumulator tile in shared memory
|
||||
typename SmemIteratorD0_,
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape1_,
|
||||
/// Iterates over the intermediate accumulator tile in shared memory
|
||||
typename WarpIteratorA1_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorB1_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB1_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Data type of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Output operator for 1st Gemm (concept: epilogue::thread::LinearCombinationClamp, etc...)
|
||||
typename OutputOp_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy0_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy1_,
|
||||
/// Transformation applied to A0 operand
|
||||
typename TransformA0_ = NumericArrayConverter<
|
||||
typename SmemIteratorA0_::Element,
|
||||
typename IteratorA0_::Element,
|
||||
IteratorA0_::Fragment::kElements>,
|
||||
///
|
||||
/// Transformation applied to B0 operand
|
||||
typename TransformB0_ = NumericArrayConverter<
|
||||
typename SmemIteratorB0_::Element,
|
||||
typename IteratorB0_::Element,
|
||||
IteratorB0_::Fragment::kElements>,
|
||||
///
|
||||
/// Transformation applied to B1 operand
|
||||
typename TransformB1_ = NumericArrayConverter<
|
||||
typename SmemIteratorB1_::Element,
|
||||
typename IteratorB1_::Element,
|
||||
IteratorB1_::Fragment::kElements>,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool
|
||||
>
|
||||
class B2bImplicitGemmPipelinedSmemAccumulator :
|
||||
public gemm::threadblock::B2bMmaBaseSmemAccumulator<Shape0_, Shape1_, Policy0_, Policy1_, SmemIteratorD0_, 2> {
|
||||
public:
|
||||
|
||||
///< Base class
|
||||
using Base = gemm::threadblock::B2bMmaBaseSmemAccumulator<Shape0_, Shape1_, Policy0_, Policy1_, SmemIteratorD0_, 2>;
|
||||
|
||||
using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory
|
||||
using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory
|
||||
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory
|
||||
using Policy0 = Policy0_; ///< Policy0 describing tuning details
|
||||
|
||||
using SmemIteratorA0 = SmemIteratorA0_;
|
||||
using SmemIteratorB0 = SmemIteratorB0_;
|
||||
using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory
|
||||
|
||||
using FragmentIteratorAccumulator = FragmentIteratorAccumulator_; ///< Iterates over accumulator tile
|
||||
|
||||
using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory
|
||||
using Policy1 = Policy1_; ///< Policy1 describing tuning details
|
||||
|
||||
using SmemIteratorB1 = SmemIteratorB1_;
|
||||
using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory
|
||||
|
||||
|
||||
using ElementC = ElementC_; ///< Data type of accumulator matrix
|
||||
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
|
||||
|
||||
using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm
|
||||
|
||||
using TransformA0 = TransformA0_;
|
||||
using TransformB0 = TransformB0_;
|
||||
using TransformB1 = TransformB1_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of operand A loaded from global memory
|
||||
using FragmentA0 = typename IteratorA0::Fragment;
|
||||
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB0 = typename IteratorB0::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC0 = typename Policy0::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator0 = typename Policy0::Operator;
|
||||
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB1 = typename IteratorB1::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC1 = typename Policy1::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator1 = typename Policy1::Operator;
|
||||
|
||||
/// Obtain the arch tag from the warp-level operator
|
||||
using ArchTag = typename Policy0::Operator::ArchTag;
|
||||
|
||||
/// Complex transform on A0 operand
|
||||
static ComplexTransform const kTransformA0 = Operator0::kTransformA;
|
||||
|
||||
/// Complex transform on B0 operand
|
||||
static ComplexTransform const kTransformB0 = Operator0::kTransformB;
|
||||
|
||||
/// Complex transform on B1 operand
|
||||
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
|
||||
|
||||
/// statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
|
||||
static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");
|
||||
|
||||
/// Epilogue in shared memory
|
||||
using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator<
|
||||
SmemIteratorD0, ///< SmemTileIterator
|
||||
FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator
|
||||
IteratorAccumulatorScaleBias, ///< ScaleBiasIterator
|
||||
OutputOp>; ///< Output operator
|
||||
|
||||
|
||||
|
||||
private:
|
||||
|
||||
using WarpFragmentA0 = typename Operator0::FragmentA;
|
||||
using WarpFragmentB0 = typename Operator0::FragmentB;
|
||||
using WarpFragmentA1 = typename Operator1::FragmentA;
|
||||
using WarpFragmentB1 = typename Operator1::FragmentB;
|
||||
|
||||
protected:
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA0 smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B0 operand to shared memory
|
||||
SmemIteratorB0 smem_iterator_B0_;
|
||||
|
||||
/// Shared Memory Iterator to store accumulator tile
|
||||
SmemIteratorD0 smem_iterator_D0_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
|
||||
WarpIteratorA1 warp_tile_iterator_A1_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B1 operand to shared memory
|
||||
SmemIteratorB1 smem_iterator_B1_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
B2bImplicitGemmPipelinedSmemAccumulator(
|
||||
typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
int thread_idx, ///< ID within the threadblock
|
||||
int warp_idx, ///< ID of warp
|
||||
int lane_idx ///< ID of each thread within a warp
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
|
||||
warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
|
||||
smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) {
|
||||
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
|
||||
int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM;
|
||||
int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM;
|
||||
|
||||
int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k_0;
|
||||
|
||||
int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN);
|
||||
int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN);
|
||||
|
||||
int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM;
|
||||
int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM;
|
||||
|
||||
int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k_1;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m_0, tile_offset_k_0});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n_0});
|
||||
warp_tile_iterator_A1_.add_tile_offset({warp_idx_m_1, tile_offset_k_1});
|
||||
this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n_1});
|
||||
|
||||
// Add smem accumulator iterator warp offset
|
||||
smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow,
|
||||
warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn});
|
||||
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
int gemm_k_iterations_0, ///< number of iterations of the mainloop
|
||||
FragmentC1 &accum, ///< destination accumulator tile
|
||||
IteratorA0 iterator_A, ///< iterator over A operand in global memory
|
||||
IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory
|
||||
IteratorAccumulatorScaleBias iterator_accum0_scale, ///< iterator over D0 scale vector in global memory
|
||||
IteratorAccumulatorScaleBias iterator_accum0_bias, ///< iterator over D0 bias vector in global memory
|
||||
IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory
|
||||
FragmentC0 const &src_accum, ///< source accumulator tile
|
||||
OutputOp output_op_0, ///< epilogue operation after 1st Gemm
|
||||
TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment
|
||||
TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment
|
||||
TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
FragmentC0 accum0 = src_accum;
|
||||
|
||||
FragmentA0 tb_frag_A;
|
||||
FragmentB0 tb_frag_B0;
|
||||
|
||||
tb_frag_A.clear();
|
||||
tb_frag_B0.clear();
|
||||
|
||||
// The last kblock is loaded in the prolog
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B0.load(tb_frag_B0);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B0;
|
||||
|
||||
this->smem_iterator_A_.store(transform_A0(tb_frag_A));
|
||||
this->smem_iterator_B0_.store(transform_B0(tb_frag_B0));
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B0_;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math instructions
|
||||
WarpFragmentA0 warp_frag_A0[2];
|
||||
WarpFragmentB0 warp_frag_B0[2];
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_frag_A0[0]);
|
||||
this->warp_tile_iterator_B0_.load(warp_frag_B0[0]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
Operator0 warp_mma0;
|
||||
|
||||
int smem_write_stage_idx = 1;
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tightest latency requirement).
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
|
||||
// as the case may be.
|
||||
|
||||
if (warp_mma_k == Base::kWarpGemmIterations0 - 1) {
|
||||
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_A_.store(transform_A0(tb_frag_A));
|
||||
|
||||
this->smem_iterator_B0_.store(transform_B0(tb_frag_B0));
|
||||
|
||||
__syncthreads();
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B0_;
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
|
||||
if (smem_write_stage_idx == 1) {
|
||||
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
|
||||
}
|
||||
else {
|
||||
this->warp_tile_iterator_A0_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset(
|
||||
{-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0,
|
||||
0});
|
||||
}
|
||||
|
||||
smem_write_stage_idx ^= 1;
|
||||
}
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B0.load(tb_frag_B0);
|
||||
++iterator_A;
|
||||
++iterator_B0;
|
||||
}
|
||||
|
||||
warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2],
|
||||
warp_frag_B0[warp_mma_k % 2], accum0);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
/// Epilogue for the first Implicit Gemm
|
||||
Epilogue0 epilogue0;
|
||||
|
||||
epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias);
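// The first GEMM's accumulators, with scale/bias applied by output_op_0, are
// written to the shared-memory accumulator tile here so that the second GEMM
// can read its A1 operand through warp_tile_iterator_A1_ instead of keeping
// it in registers.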
|
||||
|
||||
__syncthreads();
|
||||
|
||||
/// 2nd Implicit Gemm
|
||||
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
FragmentB1 tb_frag_B1;
|
||||
|
||||
tb_frag_B1.clear();
|
||||
|
||||
// The last kblock is loaded in the prolog
|
||||
iterator_B1.load(tb_frag_B1);
|
||||
|
||||
++iterator_B1;
|
||||
|
||||
this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));
|
||||
|
||||
++this->smem_iterator_B1_;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math instructions
|
||||
WarpFragmentA1 warp_frag_A1[2];
|
||||
WarpFragmentB1 warp_frag_B1[2];
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index(0);
|
||||
|
||||
warp_tile_iterator_A1_.load(warp_frag_A1[0]);
|
||||
this->warp_tile_iterator_B1_.load(warp_frag_B1[0]);
|
||||
|
||||
++warp_tile_iterator_A1_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
Operator1 warp_mma1;
|
||||
|
||||
smem_write_stage_idx = 1;
|
||||
|
||||
int gemm_k_iterations_1 = Shape0::kN / Shape1::kK;
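// The first GEMM's N extent becomes the second GEMM's K extent; e.g. a
// hypothetical Shape0::kN = 128 with Shape1::kK = 32 gives 4 mainloop
// iterations.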
|
||||
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
|
||||
// as the case may be.
|
||||
|
||||
if (warp_mma_k == Base::kWarpGemmIterations1 - 1) {
|
||||
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));
|
||||
|
||||
__syncthreads();
|
||||
|
||||
++this->smem_iterator_B1_;
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
|
||||
if (smem_write_stage_idx == 1) {
|
||||
this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
|
||||
}
|
||||
else {
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{-Base::kStages * Policy1::kPartitionsK *
|
||||
Base::kWarpGemmIterations1,
|
||||
0});
|
||||
}
|
||||
|
||||
smem_write_stage_idx ^= 1;
|
||||
|
||||
}
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
|
||||
|
||||
// skip warp tile loading for the last kgroup
|
||||
if(gemm_k_iterations_1 > 1 || warp_mma_k < Base::kWarpGemmIterations1 - 1)
|
||||
warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++warp_tile_iterator_A1_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
|
||||
iterator_B1.load(tb_frag_B1);
|
||||
|
||||
++iterator_B1;
|
||||
}
|
||||
|
||||
warp_mma1(accum, warp_frag_A1[warp_mma_k % 2],
|
||||
warp_frag_B1[warp_mma_k % 2], accum);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -180,8 +186,8 @@ class B2bMmaBase {
|
||||
using SharedStorage0 = SharedStorage<Shape0, Policy0>;
|
||||
using SharedStorage1 = SharedStorage<Shape1, Policy1>;
|
||||
union B2bMmaSharedStorage {
|
||||
SharedStorage0 sharedStorage0;
|
||||
SharedStorage1 sharedStorage1;
|
||||
SharedStorage0 shared_storage0;
|
||||
SharedStorage1 shared_storage1;
|
||||
};
|
||||
|
||||
|
||||
@ -197,7 +203,7 @@ class B2bMmaBase {
|
||||
/// Iterator to load a warp-scoped tile of B0 operand from shared memory
|
||||
typename Operator0::IteratorB warp_tile_iterator_B0_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of B0 operand from shared memory
|
||||
/// Iterator to load a warp-scoped tile of B1 operand from shared memory
|
||||
typename Operator1::IteratorB warp_tile_iterator_B1_;
|
||||
|
||||
public:
|
||||
@ -214,9 +220,9 @@ public:
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx
|
||||
):
|
||||
warp_tile_iterator_A0_(shared_storage.sharedStorage0.operand_A_ref(), lane_idx),
|
||||
warp_tile_iterator_B0_(shared_storage.sharedStorage0.operand_B_ref(), lane_idx),
|
||||
warp_tile_iterator_B1_(shared_storage.sharedStorage1.operand_B_ref(), lane_idx) {
|
||||
warp_tile_iterator_A0_(shared_storage.shared_storage0.operand_A_ref(), lane_idx),
|
||||
warp_tile_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), lane_idx),
|
||||
warp_tile_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), lane_idx) {
|
||||
|
||||
}
|
||||
};
|
||||
@ -0,0 +1,179 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "threadblock/b2b_mma_base.h"
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape0_,
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape1_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy0_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy1_,
|
||||
/// Shared Memory Accumulator Iterator
|
||||
typename SmemAccumulatorIterator0_,
|
||||
/// Number of stages
|
||||
int Stages,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class B2bMmaBaseSmemAccumulator :
|
||||
public B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, Stages> {
|
||||
|
||||
public:
|
||||
///< Base class
|
||||
using Base = B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, Stages>;
|
||||
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape0 = Shape0_;
|
||||
using Shape1 = Shape1_;
|
||||
|
||||
///< Policy describing tuning details
|
||||
using Policy0 = Policy0_;
|
||||
using Policy1 = Policy1_;
|
||||
|
||||
|
||||
using SmemAccumulatorIterator0 = SmemAccumulatorIterator0_;
|
||||
|
||||
//
|
||||
// Nested structs
|
||||
//
|
||||
/// Shared storage object needed by accumulator
|
||||
template<
|
||||
typename Shape_,
|
||||
typename Element_,
|
||||
typename Layout_,
|
||||
typename Padding_
|
||||
>
|
||||
class AccumulatorSharedStorage {
|
||||
public:
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
using Shape = Shape_;
|
||||
using Element = Element_;
|
||||
using Layout = Layout_;
|
||||
using Padding = Padding_;
|
||||
|
||||
/// Tensor reference to the accumulator
|
||||
using TensorRefAccum = TensorRef<Element, Layout>;
|
||||
|
||||
/// Shape of the accumulator matrix in shared memory
|
||||
using ShapeAccum = MatrixShape<Shape::kM + Padding::kRow,
|
||||
Shape::kN + Padding::kColumn>;
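// Padding (typically used to reduce shared-memory bank conflicts) enlarges
// the stored tile; e.g. a hypothetical 64x64 accumulator with
// Padding = <0, 8> occupies a 64x72 buffer in shared memory.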
|
||||
|
||||
public:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Buffer for accumulator
|
||||
AlignedBuffer<Element, ShapeAccum::kCount> accum;
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Returns a layout object for the Accum matrix
|
||||
CUTLASS_DEVICE
|
||||
static Layout LayoutAccum() {
|
||||
return Layout::packed({ShapeAccum::kRow, ShapeAccum::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a TensorRef to the Accumulator
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefAccum accum_ref() {
|
||||
return TensorRefAccum{accum.data(), LayoutAccum()};
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
using AccumulatorSharedStorage0 = AccumulatorSharedStorage<
|
||||
Shape0, typename SmemAccumulatorIterator0::Element,
|
||||
typename SmemAccumulatorIterator0::TensorLayout,
|
||||
typename SmemAccumulatorIterator0::Padding>;
|
||||
|
||||
struct B2bMmaSharedStorage {
|
||||
typename Base::B2bMmaSharedStorage b2b_mma_shared_storage;
|
||||
AccumulatorSharedStorage0 accumulator_shared_storage0;
|
||||
};
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
B2bMmaBaseSmemAccumulator(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
B2bMmaSharedStorage &shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx
|
||||
):
|
||||
Base(shared_storage.b2b_mma_shared_storage, thread_idx, warp_idx, lane_idx) {
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,817 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "threadblock/b2b_mma_base.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape0_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorA0_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA0_,
|
||||
/// Cache operation for operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA0,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB0_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB0_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB0,
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape1_,
|
||||
/// Iterates over the intermediate accumulator tile
|
||||
// (concept::MmaTensorOpFragmentIterator)
|
||||
typename FragmentIteratorA1_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB1_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB1_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB1,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Data type of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Output operator for 1st Gemm (concept: epilogue::thread::LinearCombinationClamp, etc...)
|
||||
typename OutputOp_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy0_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy1_,
|
||||
/// Number of stages
|
||||
int Stages,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class B2bMmaMultistage :
|
||||
public B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, Stages> {
|
||||
public:
|
||||
///< Base class
|
||||
using Base = B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, Stages>;
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape0 = Shape0_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA0 = IteratorA0_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB0 = IteratorB0_;
|
||||
///< Policy describing tuning details
|
||||
using Policy0 = Policy0_;
|
||||
|
||||
using SmemIteratorA0 = SmemIteratorA0_;
|
||||
using SmemIteratorB0 = SmemIteratorB0_;
|
||||
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape1 = Shape1_;
|
||||
///< Iterates over intermediate accumulator tile
|
||||
using FragmentIteratorA1 = FragmentIteratorA1_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB1 = IteratorB1_;
|
||||
///< Policy describing tuning details
|
||||
using Policy1 = Policy1_;
|
||||
|
||||
using SmemIteratorB1 = SmemIteratorB1_;
|
||||
|
||||
///< Data type of accumulator matrix
|
||||
using ElementC = ElementC_;
|
||||
///< Layout of accumulator matrix
|
||||
using LayoutC = LayoutC_;
|
||||
|
||||
///< Epilogue after 1st Gemm
|
||||
using OutputOp = OutputOp_;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC0 = typename Policy0::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator0 = typename Policy0::Operator;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC1 = typename Policy1::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator1 = typename Policy1::Operator;
|
||||
|
||||
/// Minimum architecture is Sm80 to support cp.async
|
||||
using ArchTag = arch::Sm80;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA0 = Operator0::kTransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB0 = Operator0::kTransformB;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
|
||||
|
||||
/// Internal structure exposed for introspection.
|
||||
struct Detail {
|
||||
|
||||
static_assert(Base::kWarpGemmIterations0 > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
static_assert(Base::kWarpGemmIterations1 > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand A
|
||||
static int const TBLDGSTSIterationsA0 =
|
||||
IteratorA0::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand B
|
||||
static int const TBLDGSTSIterationsB0 =
|
||||
IteratorB0::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand B
|
||||
static int const TBLDGSTSIterationsB1 =
|
||||
IteratorB1::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Number of cp.async instructions to load one group of operand A
|
||||
static int const kAccessesPerGroupA0 =
|
||||
(TBLDGSTSIterationsA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0;
|
||||
|
||||
/// Number of cp.async instructions to load one group of operand B
|
||||
static int const kAccessesPerGroupB0 =
|
||||
(TBLDGSTSIterationsB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0;
|
||||
|
||||
/// Number of cp.async instructions to load one group of operand B
|
||||
static int const kAccessesPerGroupB1 =
|
||||
(TBLDGSTSIterationsB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1;
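// Each kAccessesPerGroup* value is the ceiling of TBLDGSTSIterations* /
// kWarpGemmIterations*, spreading one stage's cp.async (LDGSTS) instructions
// evenly across the warp-level MMA iterations; e.g. hypothetically
// 10 iterations over 4 warp MMAs gives ceil(10 / 4) = 3 accesses per group.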
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
using WarpLoadedFragmentA0 = typename Operator0::FragmentA;
|
||||
using WarpLoadedFragmentB0 = typename Operator0::FragmentB;
|
||||
/// Warp Fragment of operand A1 loaded from accumulator tile
|
||||
using WarpLoadedFragmentA1 = typename FragmentIteratorA1::Fragment;
|
||||
using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
|
||||
using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB;
|
||||
using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB;
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA0 smem_iterator_A0_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB0 smem_iterator_B0_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB1 smem_iterator_B1_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
B2bMmaMultistage(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::B2bMmaSharedStorage &shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A0_(shared_storage.shared_storage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx)
|
||||
{
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A0_.add_tile_offset(
|
||||
{warp_idx_m, Base::kWarpGemmIterations0 * warp_idx_k});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations0 * warp_idx_k, warp_idx_n});
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations1 * warp_idx_k, warp_idx_n});
|
||||
}
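
// Editorial note (illustrative values, not in the original source): the
// decomposition above maps a linear warp_idx onto (m, n, k) coordinates.
// With a WarpCount0 of kM = 2 and kN = 2, warp_idx = 5 yields warp_idx_mn = 1
// and warp_idx_k = 1, hence warp_idx_m = 1 and warp_idx_n = 0; the
// add_tile_offset calls then position that warp on its slice of the
// threadblock tile in units of warp-level tiles.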
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_0(IteratorA0 &iterator_A0, IteratorB0 &iterator_B0,
|
||||
int group_start_A0 = 0, int group_start_B0 = 0) {
|
||||
iterator_A0.set_iteration_index(group_start_A0 *
|
||||
IteratorA0::kAccessesPerVector);
|
||||
this->smem_iterator_A0_.set_iteration_index(group_start_A0);
|
||||
|
||||
// LDGSTS for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) {
|
||||
if (group_start_A0 + j < Detail::TBLDGSTSIterationsA0) {
|
||||
typename IteratorA0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA0::AccessType *>(
|
||||
this->smem_iterator_A0_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA0::Element>::value *
|
||||
IteratorA0::ThreadMap::kElementsPerAccess /
|
||||
IteratorA0::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_A0.get();
|
||||
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpA0>(
|
||||
dst_ptr + v, gmem_ptr, iterator_A0.valid());
|
||||
|
||||
++iterator_A0;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A0_;
|
||||
}
|
||||
}
|
||||
|
||||
iterator_B0.set_iteration_index(group_start_B0 *
|
||||
IteratorB0::kAccessesPerVector);
|
||||
this->smem_iterator_B0_.set_iteration_index(group_start_B0);
|
||||
|
||||
// LDGSTS for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) {
|
||||
if (group_start_B0 + j < Detail::TBLDGSTSIterationsB0) {
|
||||
typename IteratorB0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB0::AccessType *>(
|
||||
this->smem_iterator_B0_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB0::Element>::value *
|
||||
IteratorB0::ThreadMap::kElementsPerAccess /
|
||||
IteratorB0::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_B0.get();
|
||||
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpB0>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B0.valid());
|
||||
|
||||
++iterator_B0;
|
||||
}
|
||||
++this->smem_iterator_B0_;
|
||||
}
|
||||
}
|
||||
}
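
// Editorial note (illustrative values, not in the original source): kSrcBytes
// in the routine above is the size of a single cp.async transfer. For a
// half-precision operand with 8 elements per access and kAccessesPerVector = 1
// it evaluates to 16 * 8 / 1 / 8 = 16 bytes, the widest LDGSTS transfer
// supported on SM80.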
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_1(IteratorB1 &iterator_B1,
|
||||
int group_start_B1 = 0) {
|
||||
iterator_B1.set_iteration_index(group_start_B1 *
|
||||
IteratorB1::kAccessesPerVector);
|
||||
this->smem_iterator_B1_.set_iteration_index(group_start_B1);
|
||||
|
||||
// LDGSTS for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) {
|
||||
if (group_start_B1 + j < Detail::TBLDGSTSIterationsB1) {
|
||||
typename IteratorB1::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB1::AccessType *>(
|
||||
this->smem_iterator_B1_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB1::Element>::value *
|
||||
IteratorB1::ThreadMap::kElementsPerAccess /
|
||||
IteratorB1::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_B1.get();
|
||||
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpB1>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B1.valid());
|
||||
|
||||
++iterator_B1;
|
||||
}
|
||||
++this->smem_iterator_B1_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
///< problem size of GEMM
|
||||
int gemm_k_iterations_0,
|
||||
///< destination accumulator tile
|
||||
FragmentC1 &accum,
|
||||
///< iterator over A operand in global memory
|
||||
IteratorA0 iterator_A0,
|
||||
///< iterator over B operand in global memory
|
||||
IteratorB0 iterator_B0,
|
||||
///< iterator over B operand in global memory
|
||||
IteratorB1 iterator_B1,
|
||||
///< initial value of accumulator
|
||||
FragmentC0 const &src_accum,
|
||||
///< epilogue operation after 1st Gemm
|
||||
OutputOp output_op_0)
|
||||
{
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations_0) {
|
||||
|
||||
iterator_A0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
|
||||
iterator_A0.set_iteration_index(0);
|
||||
this->smem_iterator_A0_.set_iteration_index(0);
|
||||
|
||||
// LDGSTS for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::TBLDGSTSIterationsA0; ++j) {
|
||||
typename IteratorA0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA0::AccessType *>(
|
||||
this->smem_iterator_A0_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorA0::Element>::value *
|
||||
IteratorA0::ThreadMap::kElementsPerAccess /
|
||||
IteratorA0::kAccessesPerVector / 8;
|
||||
|
||||
int src_bytes = (iterator_A0.valid() ? kSrcBytes : 0);
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA0>(
|
||||
dst_ptr + v, iterator_A0.get(), iterator_A0.valid());
|
||||
|
||||
++iterator_A0;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A0_;
|
||||
}
|
||||
|
||||
iterator_B0.set_iteration_index(0);
|
||||
this->smem_iterator_B0_.set_iteration_index(0);
|
||||
|
||||
// LDGSTS for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::TBLDGSTSIterationsB0; ++j) {
|
||||
typename IteratorB0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB0::AccessType *>(
|
||||
this->smem_iterator_B0_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB0::Element>::value *
|
||||
IteratorB0::ThreadMap::kElementsPerAccess /
|
||||
IteratorB0::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB0>(
|
||||
dst_ptr + v, iterator_B0.get(), iterator_B0.valid());
|
||||
|
||||
++iterator_B0;
|
||||
}
|
||||
|
||||
++this->smem_iterator_B0_;
|
||||
}
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A0.add_tile_offset({0, 1});
|
||||
iterator_B0.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_A0_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B0_.add_tile_offset({1, 0});
|
||||
|
||||
// Defines the boundary of a stage of cp.async.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
FragmentC0 accum0 = src_accum;
|
||||
|
||||
// DEPBAR+SYNC
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
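
// Editorial note (not in the original source): cp_async_wait<Base::kStages - 2>
// blocks until at most kStages - 2 committed cp.async stages remain in flight,
// so the oldest prologue stage has landed in shared memory; the __syncthreads()
// above then makes that stage visible to every warp before the first warp-level
// fragments are loaded.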
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpLoadedFragmentA0 warp_loaded_frag_A0[2];
|
||||
WarpLoadedFragmentB0 warp_loaded_frag_B0[2];
|
||||
WarpTransformedFragmentA0 warp_transformed_frag_A0[2];
|
||||
WarpTransformedFragmentB0 warp_transformed_frag_B0[2];
|
||||
|
||||
Operator0 warp_mma0;
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]);
|
||||
this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
iterator_A0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
|
||||
warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0],
|
||||
warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations_0 > (-Base::kStages + 1);) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0;
|
||||
++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping the k-group index
// back to zero when this is the last group of the stage.
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
if (warp_mma_k > 0)
|
||||
warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2],
|
||||
warp_transformed_frag_B0[warp_mma_k % 2],
|
||||
warp_loaded_frag_A0[warp_mma_k % 2],
|
||||
warp_loaded_frag_B0[warp_mma_k % 2]);
|
||||
|
||||
warp_mma0(
|
||||
accum0,
|
||||
warp_transformed_frag_A0[warp_mma_k % 2],
|
||||
warp_transformed_frag_B0[warp_mma_k % 2],
|
||||
accum0
|
||||
);
|
||||
|
||||
// Issue global->shared copies for this stage
|
||||
if (warp_mma_k < Base::kWarpGemmIterations0 - 1) {
|
||||
int group_start_iteration_A0, group_start_iteration_B0;
|
||||
|
||||
group_start_iteration_A0 = warp_mma_k * Detail::kAccessesPerGroupA0;
|
||||
group_start_iteration_B0 = warp_mma_k * Detail::kAccessesPerGroupB0;
|
||||
|
||||
copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0,
|
||||
group_start_iteration_B0);
|
||||
}
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations0) {
|
||||
int group_start_iteration_A0, group_start_iteration_B0;
|
||||
group_start_iteration_A0 =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupA0;
|
||||
group_start_iteration_B0 =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupB0;
|
||||
|
||||
copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0,
|
||||
group_start_iteration_B0);
|
||||
|
||||
// Inserts a memory fence between stages of cp.async instructions.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Waits until kStages-2 stages have committed.
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A0.add_tile_offset({0, 1});
|
||||
iterator_B0.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_A0_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B0_.add_tile_offset({1, 0});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1)) {
|
||||
this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx = 0;
|
||||
} else {
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1)) {
|
||||
this->warp_tile_iterator_A0_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy0::kPartitionsK *
|
||||
Base::kWarpGemmIterations0});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset(
|
||||
{-Base::kStages * Policy0::kPartitionsK *
|
||||
Base::kWarpGemmIterations0,
|
||||
0});
|
||||
smem_read_stage_idx = 0;
|
||||
} else {
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
--gemm_k_iterations_0;
|
||||
iterator_A0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
}
|
||||
|
||||
// Do any conversions feeding the first stage at the end of the loop so
|
||||
// we can start right away on mma instructions
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations0)
|
||||
warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B0[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A0[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
// 2nd Gemm
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
|
||||
FragmentIteratorA1 warp_tile_iterator_A1_(accum0);
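
// Editorial note (not in the original source): in this variant the second GEMM
// never stages its A operand through shared memory. The fragment iterator
// constructed above re-slices the first GEMM's accumulator registers (accum0)
// into warp-level A1 fragments, applying output_op_0 as each fragment is loaded.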
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1;
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations_1) {
|
||||
|
||||
iterator_B1.clear_mask(gemm_k_iterations_1 == 0);
|
||||
|
||||
iterator_B1.set_iteration_index(0);
|
||||
this->smem_iterator_B1_.set_iteration_index(0);
|
||||
|
||||
// LDGSTS for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::TBLDGSTSIterationsB1; ++j) {
|
||||
typename IteratorB1::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB1::AccessType *>(
|
||||
this->smem_iterator_B1_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB1::Element>::value *
|
||||
IteratorB1::ThreadMap::kElementsPerAccess /
|
||||
IteratorB1::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB1>(
|
||||
dst_ptr + v, iterator_B1.get(), iterator_B1.valid());
|
||||
|
||||
++iterator_B1;
|
||||
}
|
||||
|
||||
++this->smem_iterator_B1_;
|
||||
}
|
||||
|
||||
// Move to the next stage
|
||||
iterator_B1.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_B1_.add_tile_offset({1, 0});
|
||||
|
||||
// Defines the boundary of a stage of cp.async.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// DEPBAR+SYNC
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpLoadedFragmentA1 warp_loaded_frag_A1[2];
|
||||
WarpLoadedFragmentB1 warp_loaded_frag_B1[2];
|
||||
WarpTransformedFragmentA1 warp_transformed_frag_A1[2];
|
||||
WarpTransformedFragmentB1 warp_transformed_frag_B1[2];
|
||||
|
||||
Operator1 warp_mma1;
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index(0);
|
||||
|
||||
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0], output_op_0);
|
||||
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]);
|
||||
|
||||
++warp_tile_iterator_A1_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
iterator_B1.clear_mask(gemm_k_iterations_1 == 0);
|
||||
|
||||
smem_write_stage_idx = Base::kStages - 1;
|
||||
smem_read_stage_idx = 0;
|
||||
|
||||
warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0],
|
||||
warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1 - (Base::kStages - 1);
|
||||
gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1;
|
||||
++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping the k-group index
// back to zero when this is the last group of the stage.
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
|
||||
|
||||
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2], output_op_0);
|
||||
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++warp_tile_iterator_A1_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
if (warp_mma_k > 0)
|
||||
warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
warp_loaded_frag_A1[warp_mma_k % 2],
|
||||
warp_loaded_frag_B1[warp_mma_k % 2]);
|
||||
|
||||
|
||||
warp_mma1(
|
||||
accum,
|
||||
warp_transformed_frag_A1[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
accum
|
||||
);
|
||||
|
||||
// Issue global->shared copies for this stage
|
||||
if (warp_mma_k < Base::kWarpGemmIterations1 - 1) {
|
||||
int group_start_iteration_B1;
|
||||
|
||||
group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1;
|
||||
|
||||
copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1);
|
||||
}
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations1) {
|
||||
int group_start_iteration_B1;
|
||||
group_start_iteration_B1 =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupB1;
|
||||
|
||||
copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1);
|
||||
|
||||
// Inserts a memory fence between stages of cp.async instructions.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Waits until kStages-2 stages have committed.
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_B1.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_B1_.add_tile_offset({1, 0});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1)) {
|
||||
this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx = 0;
|
||||
} else {
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1)) {
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{-Base::kStages * Policy1::kPartitionsK *
|
||||
Base::kWarpGemmIterations1,
|
||||
0});
|
||||
smem_read_stage_idx = 0;
|
||||
} else {
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
iterator_B1.clear_mask(gemm_k_iterations_1 == 1);
|
||||
}
|
||||
|
||||
// Do any conversions feeding the first stage at the end of the loop so
|
||||
// we can start right away on mma instructions
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations1)
|
||||
warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A1[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
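
The mainloops in this file follow the standard SM80 multistage pattern: a prologue that issues kStages - 1 complete cp.async stages, then a steady state that overlaps warp-level MMAs on the stage that has just landed with LDGSTS refills of the slot that was just consumed. The sketch below restates that pattern with the CUDA pipeline primitives from `<cuda_pipeline.h>` instead of the `cutlass::arch` wrappers; the kernel name, tile size, and the sum it computes are illustrative assumptions, not part of this diff.

```cpp
// Minimal sketch of the cp.async software pipeline used above (SM80+).
// Assumes num_tiles >= kStages and that *out is zero-initialized by the caller.
#include <cuda_pipeline.h>

constexpr int kStages = 3;    // pipeline depth, mirrors Base::kStages
constexpr int kTile   = 128;  // elements staged per cp.async stage

__global__ void pipelined_sum(const float *global_in, float *out, int num_tiles) {
  __shared__ float smem[kStages][kTile];
  float acc = 0.f;

  // Prologue: issue kStages - 1 complete stages before any math.
  for (int s = 0; s < kStages - 1; ++s) {
    for (int i = threadIdx.x; i < kTile; i += blockDim.x)
      __pipeline_memcpy_async(&smem[s][i], &global_in[s * kTile + i], sizeof(float));
    __pipeline_commit();  // stage boundary, like cutlass::arch::cp_async_fence()
  }

  for (int t = 0; t < num_tiles; ++t) {
    // Like cp_async_wait<Base::kStages - 2>() + __syncthreads(): wait until the
    // oldest outstanding stage has landed, then make it visible to all warps.
    __pipeline_wait_prior(kStages - 2);
    __syncthreads();

    // "Math" on the stage that just arrived.
    for (int i = threadIdx.x; i < kTile; i += blockDim.x)
      acc += smem[t % kStages][i];
    __syncthreads();

    // Refill the slot that was just consumed with the next tile, if any.
    int next = t + kStages - 1;
    if (next < num_tiles)
      for (int i = threadIdx.x; i < kTile; i += blockDim.x)
        __pipeline_memcpy_async(&smem[next % kStages][i],
                                &global_in[next * kTile + i], sizeof(float));
    __pipeline_commit();  // commit even an empty group to keep the count aligned
  }

  atomicAdd(out, acc);
}
```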
|
||||
@ -0,0 +1,860 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "threadblock/b2b_mma_base_smem_accumulator.h"
|
||||
#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape0_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorA0_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA0_,
|
||||
/// Cache operation for operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA0,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB0_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB0_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB0,
|
||||
/// Iterates over vectors of scale and bias vector in global memory
|
||||
// (concept: VectorIterator)
|
||||
typename IteratorAccumulatorScaleBias_,
|
||||
/// Iterates over accumulator tile
|
||||
typename FragmentIteratorAccumulator_,
|
||||
/// Iterates over accumulator tile in shared memory
|
||||
typename SmemIteratorD0_,
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape1_,
|
||||
/// Iterates over the intermediate accumulator tile in shared memory
|
||||
typename WarpIteratorA1_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB1_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB1_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB1,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Data type of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...)
|
||||
typename OutputOp_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy0_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy1_,
|
||||
/// Number of stages
|
||||
int Stages,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class B2bMmaMultistageSmemAccumulator :
|
||||
public gemm::threadblock::B2bMmaBaseSmemAccumulator<Shape0_, Shape1_, Policy0_, Policy1_, SmemIteratorD0_, Stages> {
|
||||
public:
|
||||
///< Base class
|
||||
using Base = gemm::threadblock::B2bMmaBaseSmemAccumulator<Shape0_, Shape1_, Policy0_, Policy1_, SmemIteratorD0_, Stages>;
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape0 = Shape0_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA0 = IteratorA0_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB0 = IteratorB0_;
|
||||
///< Iterates over tiles of the scale and bias vectors in global memory
|
||||
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_;
|
||||
///< Policy describing tuning details
|
||||
using Policy0 = Policy0_;
|
||||
|
||||
using SmemIteratorA0 = SmemIteratorA0_;
|
||||
using SmemIteratorB0 = SmemIteratorB0_;
|
||||
using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory
|
||||
|
||||
using FragmentIteratorAccumulator = FragmentIteratorAccumulator_; ///< Iterates over accumulator tile
|
||||
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape1 = Shape1_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB1 = IteratorB1_;
|
||||
///< Policy describing tuning details
|
||||
using Policy1 = Policy1_;
|
||||
|
||||
using SmemIteratorB1 = SmemIteratorB1_;
|
||||
using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory
|
||||
|
||||
///< Data type of accumulator matrix
|
||||
using ElementC = ElementC_;
|
||||
///< Layout of accumulator matrix
|
||||
using LayoutC = LayoutC_;
|
||||
|
||||
///< Epilogue after 1st Gemm
|
||||
using OutputOp = OutputOp_;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC0 = typename Policy0::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator0 = typename Policy0::Operator;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC1 = typename Policy1::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator1 = typename Policy1::Operator;
|
||||
|
||||
/// Epilogue in shared memory
using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator<
SmemIteratorD0, ///< SmemTileIterator
FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator
IteratorAccumulatorScaleBias, ///< ScaleBiasIterator
OutputOp>; ///< Output operator
|
||||
|
||||
/// Minimum architecture is Sm80 to support cp.async
|
||||
using ArchTag = arch::Sm80;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA0 = Operator0::kTransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB0 = Operator0::kTransformB;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
|
||||
|
||||
/// Internal structure exposed for introspection.
|
||||
struct Detail {
|
||||
|
||||
static_assert(Base::kWarpGemmIterations0 > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
static_assert(Base::kWarpGemmIterations1 > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand A
|
||||
static int const TBLDGSTSIterationsA0 =
|
||||
IteratorA0::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand B
|
||||
static int const TBLDGSTSIterationsB0 =
|
||||
IteratorB0::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand B
|
||||
static int const TBLDGSTSIterationsB1 =
|
||||
IteratorB1::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Number of cp.async instructions to load one group of operand A
static int const kAccessesPerGroupA0 =
(TBLDGSTSIterationsA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0;

/// Number of cp.async instructions to load one group of operand B
static int const kAccessesPerGroupB0 =
(TBLDGSTSIterationsB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0;

/// Number of cp.async instructions to load one group of operand B1
static int const kAccessesPerGroupB1 =
(TBLDGSTSIterationsB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1;
};
|
||||
|
||||
private:
|
||||
|
||||
using WarpLoadedFragmentA0 = typename Operator0::FragmentA;
|
||||
using WarpLoadedFragmentB0 = typename Operator0::FragmentB;
|
||||
using WarpLoadedFragmentA1 = typename Operator1::FragmentA;
|
||||
using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
|
||||
using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB;
|
||||
using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB;
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA0 smem_iterator_A0_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB0 smem_iterator_B0_;
|
||||
|
||||
/// Shared Memory Iterator to store accumulator tile
|
||||
SmemIteratorD0 smem_iterator_D0_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
|
||||
WarpIteratorA1 warp_tile_iterator_A1_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB1 smem_iterator_B1_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
B2bMmaMultistageSmemAccumulator(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::B2bMmaSharedStorage &shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
|
||||
warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
|
||||
smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx)
|
||||
{
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
|
||||
int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM;
|
||||
int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM;
|
||||
|
||||
int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN);
|
||||
int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN);
|
||||
|
||||
int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM;
|
||||
int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A0_.add_tile_offset(
|
||||
{warp_idx_m_0, Base::kWarpGemmIterations0 * warp_idx_k_0});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations0 * warp_idx_k_0, warp_idx_n_0});
|
||||
warp_tile_iterator_A1_.add_tile_offset(
|
||||
{warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1});
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1});
|
||||
|
||||
// Add smem accumulator iterator warp offset
|
||||
smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow,
|
||||
warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn});
|
||||
}
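
// Editorial note (not in the original source): compared with the constructor of
// B2bMmaMultistage above, this one additionally positions smem_iterator_D0_ and
// warp_tile_iterator_A1_ on the shared-memory accumulator tile, so each warp
// writes its slice of the first GEMM's result through the epilogue and later
// streams it back as the A operand of the second GEMM.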
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_0(IteratorA0 &iterator_A0, IteratorB0 &iterator_B0,
|
||||
int group_start_A0 = 0, int group_start_B0 = 0) {
|
||||
iterator_A0.set_iteration_index(group_start_A0 *
|
||||
IteratorA0::kAccessesPerVector);
|
||||
this->smem_iterator_A0_.set_iteration_index(group_start_A0);
|
||||
|
||||
// LDGSTS for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) {
|
||||
if (group_start_A0 + j < Detail::TBLDGSTSIterationsA0) {
|
||||
typename IteratorA0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA0::AccessType *>(
|
||||
this->smem_iterator_A0_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA0::Element>::value *
|
||||
IteratorA0::ThreadMap::kElementsPerAccess /
|
||||
IteratorA0::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_A0.get();
|
||||
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpA0>(
|
||||
dst_ptr + v, gmem_ptr, iterator_A0.valid());
|
||||
|
||||
++iterator_A0;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A0_;
|
||||
}
|
||||
}
|
||||
|
||||
iterator_B0.set_iteration_index(group_start_B0 *
|
||||
IteratorB0::kAccessesPerVector);
|
||||
this->smem_iterator_B0_.set_iteration_index(group_start_B0);
|
||||
|
||||
// LDGSTS for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) {
|
||||
if (group_start_B0 + j < Detail::TBLDGSTSIterationsB0) {
|
||||
typename IteratorB0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB0::AccessType *>(
|
||||
this->smem_iterator_B0_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB0::Element>::value *
|
||||
IteratorB0::ThreadMap::kElementsPerAccess /
|
||||
IteratorB0::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_B0.get();
|
||||
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpB0>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B0.valid());
|
||||
|
||||
++iterator_B0;
|
||||
}
|
||||
++this->smem_iterator_B0_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_1(IteratorB1 &iterator_B1,
|
||||
int group_start_B1 = 0) {
|
||||
iterator_B1.set_iteration_index(group_start_B1 *
|
||||
IteratorB1::kAccessesPerVector);
|
||||
this->smem_iterator_B1_.set_iteration_index(group_start_B1);
|
||||
|
||||
// LDGSTS for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) {
|
||||
if (group_start_B1 + j < Detail::TBLDGSTSIterationsB1) {
|
||||
typename IteratorB1::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB1::AccessType *>(
|
||||
this->smem_iterator_B1_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB1::Element>::value *
|
||||
IteratorB1::ThreadMap::kElementsPerAccess /
|
||||
IteratorB1::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_B1.get();
|
||||
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpB1>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B1.valid());
|
||||
|
||||
++iterator_B1;
|
||||
}
|
||||
++this->smem_iterator_B1_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
///< problem size of GEMM
|
||||
int gemm_k_iterations_0,
|
||||
///< destination accumulator tile
|
||||
FragmentC1 &accum,
|
||||
///< iterator over A operand in global memory
|
||||
IteratorA0 iterator_A0,
|
||||
///< iterator over B operand in global memory
|
||||
IteratorB0 iterator_B0,
|
||||
///< iterator over B operand in global memory
|
||||
IteratorB1 iterator_B1,
|
||||
///< initial value of accumulator
|
||||
FragmentC0 const &src_accum,
|
||||
///< epilogue operation after 1st Gemm
|
||||
OutputOp output_op_0)
|
||||
{
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations_0) {
|
||||
|
||||
iterator_A0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
|
||||
iterator_A0.set_iteration_index(0);
|
||||
this->smem_iterator_A0_.set_iteration_index(0);
|
||||
|
||||
// LDGSTS for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::TBLDGSTSIterationsA0; ++j) {
|
||||
typename IteratorA0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA0::AccessType *>(
|
||||
this->smem_iterator_A0_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorA0::Element>::value *
|
||||
IteratorA0::ThreadMap::kElementsPerAccess /
|
||||
IteratorA0::kAccessesPerVector / 8;
|
||||
|
||||
int src_bytes = (iterator_A0.valid() ? kSrcBytes : 0);
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA0>(
|
||||
dst_ptr + v, iterator_A0.get(), iterator_A0.valid());
|
||||
|
||||
++iterator_A0;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A0_;
|
||||
}
|
||||
|
||||
iterator_B0.set_iteration_index(0);
|
||||
this->smem_iterator_B0_.set_iteration_index(0);
|
||||
|
||||
// LDGSTS for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::TBLDGSTSIterationsB0; ++j) {
|
||||
typename IteratorB0::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB0::AccessType *>(
|
||||
this->smem_iterator_B0_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB0::Element>::value *
|
||||
IteratorB0::ThreadMap::kElementsPerAccess /
|
||||
IteratorB0::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB0>(
|
||||
dst_ptr + v, iterator_B0.get(), iterator_B0.valid());
|
||||
|
||||
++iterator_B0;
|
||||
}
|
||||
|
||||
++this->smem_iterator_B0_;
|
||||
}
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A0.add_tile_offset({0, 1});
|
||||
iterator_B0.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_A0_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B0_.add_tile_offset({1, 0});
|
||||
|
||||
// Defines the boundary of a stage of cp.async.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
FragmentC0 accum0 = src_accum;
|
||||
|
||||
// DEPBAR+SYNC
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpLoadedFragmentA0 warp_loaded_frag_A0[2];
|
||||
WarpLoadedFragmentB0 warp_loaded_frag_B0[2];
|
||||
WarpTransformedFragmentA0 warp_transformed_frag_A0[2];
|
||||
WarpTransformedFragmentB0 warp_transformed_frag_B0[2];
|
||||
|
||||
Operator0 warp_mma0;
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]);
|
||||
this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
iterator_A0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
|
||||
warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0],
|
||||
warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations_0 > (-Base::kStages + 1);) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0;
|
||||
++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping the k-group index
// back to zero when this is the last group of the stage.
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
if (warp_mma_k > 0)
|
||||
warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2],
|
||||
warp_transformed_frag_B0[warp_mma_k % 2],
|
||||
warp_loaded_frag_A0[warp_mma_k % 2],
|
||||
warp_loaded_frag_B0[warp_mma_k % 2]);
|
||||
|
||||
warp_mma0(
|
||||
accum0,
|
||||
warp_transformed_frag_A0[warp_mma_k % 2],
|
||||
warp_transformed_frag_B0[warp_mma_k % 2],
|
||||
accum0
|
||||
);
|
||||
|
||||
// Issue global->shared copies for this stage
|
||||
if (warp_mma_k < Base::kWarpGemmIterations0 - 1) {
|
||||
int group_start_iteration_A0, group_start_iteration_B0;
|
||||
|
||||
group_start_iteration_A0 = warp_mma_k * Detail::kAccessesPerGroupA0;
|
||||
group_start_iteration_B0 = warp_mma_k * Detail::kAccessesPerGroupB0;
|
||||
|
||||
copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0,
|
||||
group_start_iteration_B0);
|
||||
}
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations0) {
|
||||
int group_start_iteration_A0, group_start_iteration_B0;
|
||||
group_start_iteration_A0 =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupA0;
|
||||
group_start_iteration_B0 =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupB0;
|
||||
|
||||
copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0,
|
||||
group_start_iteration_B0);
|
||||
|
||||
// Inserts a memory fence between stages of cp.async instructions.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Waits until kStages-2 stages have committed.
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A0.add_tile_offset({0, 1});
|
||||
iterator_B0.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_A0_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B0_.add_tile_offset({1, 0});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1)) {
|
||||
this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx = 0;
|
||||
} else {
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1)) {
|
||||
this->warp_tile_iterator_A0_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy0::kPartitionsK *
|
||||
Base::kWarpGemmIterations0});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset(
|
||||
{-Base::kStages * Policy0::kPartitionsK *
|
||||
Base::kWarpGemmIterations0,
|
||||
0});
|
||||
smem_read_stage_idx = 0;
|
||||
} else {
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
--gemm_k_iterations_0;
|
||||
iterator_A0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
}
|
||||
|
||||
// Do any conversions feeding the first stage at the end of the loop so
|
||||
// we can start right away on mma instructions
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations0)
|
||||
warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B0[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A0[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// Epilogue for the first Implicit Gemm
|
||||
Epilogue0 epilogue0;
|
||||
|
||||
epilogue0(output_op_0, smem_iterator_D0_, accum0);
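
// Editorial note (not in the original source): the call above applies
// output_op_0 to accum0 and stores the result into the shared-memory
// accumulator tile via smem_iterator_D0_; the __syncthreads() that follows
// makes the tile visible to all warps before warp_tile_iterator_A1_ reads it
// back as the A operand of the second GEMM.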
|
||||
|
||||
__syncthreads();
|
||||
|
||||
|
||||
// 2nd Gemm
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
int gemm_k_iterations_1 = Shape0::kN / Shape1::kK;
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations_1) {
|
||||
|
||||
iterator_B1.clear_mask(gemm_k_iterations_1 == 0);
|
||||
|
||||
iterator_B1.set_iteration_index(0);
|
||||
this->smem_iterator_B1_.set_iteration_index(0);
|
||||
|
||||
// LDGSTS for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::TBLDGSTSIterationsB1; ++j) {
|
||||
typename IteratorB1::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB1::AccessType *>(
|
||||
this->smem_iterator_B1_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB1::Element>::value *
|
||||
IteratorB1::ThreadMap::kElementsPerAccess /
|
||||
IteratorB1::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB1>(
|
||||
dst_ptr + v, iterator_B1.get(), iterator_B1.valid());
|
||||
|
||||
++iterator_B1;
|
||||
}
|
||||
|
||||
++this->smem_iterator_B1_;
|
||||
}
|
||||
|
||||
// Move to the next stage
|
||||
iterator_B1.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_B1_.add_tile_offset({1, 0});
|
||||
|
||||
// Defines the boundary of a stage of cp.async.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// DEPBAR+SYNC
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpLoadedFragmentA1 warp_loaded_frag_A1[2];
|
||||
WarpLoadedFragmentB1 warp_loaded_frag_B1[2];
|
||||
WarpTransformedFragmentA1 warp_transformed_frag_A1[2];
|
||||
WarpTransformedFragmentB1 warp_transformed_frag_B1[2];
|
||||
|
||||
Operator1 warp_mma1;
|
||||
|
||||
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]);
|
||||
++warp_tile_iterator_A1_;
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]);
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
iterator_B1.clear_mask(gemm_k_iterations_1 == 0);
|
||||
|
||||
smem_write_stage_idx = Base::kStages - 1;
|
||||
smem_read_stage_idx = 0;
|
||||
|
||||
warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0],
|
||||
warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for ( gemm_k_iterations_1 = Shape0::kN / Shape1::kK - (Base::kStages - 1);
|
||||
gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1;
|
||||
++warp_mma_k) {
|
||||
|
||||
// Load warp-level tile from accumulator fragment
|
||||
// skip warp tile loading for the last kgroup
|
||||
if(gemm_k_iterations_1 > (-Base::kStages + 2) || warp_mma_k < Base::kWarpGemmIterations1 - 1) {
|
||||
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2]);
|
||||
}
|
||||
++warp_tile_iterator_A1_;
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping the k-group index
// back to zero when this is the last group of the stage.
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
|
||||
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
|
||||
if (warp_mma_k > 0)
|
||||
warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
warp_loaded_frag_A1[warp_mma_k % 2],
|
||||
warp_loaded_frag_B1[warp_mma_k % 2]);
|
||||
|
||||
|
||||
warp_mma1(
|
||||
accum,
|
||||
warp_transformed_frag_A1[warp_mma_k % 2],
|
||||
warp_transformed_frag_B1[warp_mma_k % 2],
|
||||
accum
|
||||
);
|
||||
|
||||
// Issue global->shared copies for this stage
|
||||
if (warp_mma_k < Base::kWarpGemmIterations1 - 1) {
|
||||
int group_start_iteration_B1;
|
||||
|
||||
group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1;
|
||||
|
||||
copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1);
|
||||
}
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations1) {
|
||||
int group_start_iteration_B1;
|
||||
group_start_iteration_B1 =
|
||||
(warp_mma_k + 1) * Detail::kAccessesPerGroupB1;
|
||||
|
||||
copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1);
|
||||
|
||||
// Inserts a memory fence between stages of cp.async instructions.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Waits until kStages-2 stages have committed.
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_B1.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_B1_.add_tile_offset({1, 0});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1)) {
|
||||
this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx = 0;
|
||||
} else {
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1)) {
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{-Base::kStages * Policy1::kPartitionsK *
|
||||
Base::kWarpGemmIterations1,
|
||||
0});
|
||||
smem_read_stage_idx = 0;
|
||||
} else {
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
iterator_B1.clear_mask(gemm_k_iterations_1 == 1);
|
||||
}
|
||||
|
||||
// Do any conversions feeding the first stage at the end of the loop so
|
||||
// we can start right away on mma instructions
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations1)
|
||||
warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2],
|
||||
warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_A1[(warp_mma_k + 1) % 2],
|
||||
warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
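
One practical difference between the two mainloops above is how the second GEMM's K extent is derived: the register variant divides the accumulator fragment iterator's iteration count by kWarpGemmIterations1, while the shared-memory variant divides the first tile's N extent by the second tile's K extent. A small host-side sketch of the latter follows; the TileShape struct and the tile sizes are hypothetical stand-ins, not taken from this diff.

```cpp
// Illustrates "int gemm_k_iterations_1 = Shape0::kN / Shape1::kK;" with
// hypothetical threadblock tile shapes.
#include <cstdio>

struct TileShape { int kM, kN, kK; };

int main() {
  TileShape shape0{128, 64, 32};   // first GEMM threadblock tile
  TileShape shape1{128, 128, 32};  // second GEMM threadblock tile

  // The first GEMM's N extent becomes the second GEMM's K extent, because the
  // accumulator tile is reused as the A1 operand.
  int gemm_k_iterations_1 = shape0.kN / shape1.kK;
  std::printf("2nd GEMM K iterations: %d\n", gemm_k_iterations_1);  // prints 2
  return 0;
}
```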
|
||||
@ -1,24 +1,30 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
@ -48,10 +54,6 @@ namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template<int a>
|
||||
struct chk_val {
|
||||
static_assert(a==0, "check value");
|
||||
};
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
||||
template <
|
||||
@ -110,7 +112,8 @@ template <
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool
|
||||
>
|
||||
class B2bMmaPipelined : public B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, 2> {
|
||||
class B2bMmaPipelined :
|
||||
public B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, 2> {
|
||||
public:
|
||||
|
||||
///< Base class
|
||||
@ -134,7 +137,7 @@ public:
|
||||
|
||||
using ElementC = ElementC_; ///< Data type of accumulator matrix
|
||||
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
|
||||
|
||||
|
||||
using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm
|
||||
|
||||
using TransformA0 = TransformA0_;
|
||||
@ -156,7 +159,7 @@ public:
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator0 = typename Policy0::Operator;
|
||||
|
||||
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB1 = typename IteratorB1::Fragment;
|
||||
|
||||
@ -165,7 +168,7 @@ public:
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator1 = typename Policy1::Operator;
|
||||
|
||||
|
||||
/// Obtain the arch tag from the warp-level operator
|
||||
using ArchTag = typename Policy0::Operator::ArchTag;
|
||||
|
||||
@ -178,7 +181,7 @@ public:
|
||||
/// Complex transform on B1 operand
|
||||
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
|
||||
|
||||
// statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
/// statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
|
||||
static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");
|
||||
|
||||
private:
|
||||
@ -211,9 +214,9 @@ public:
|
||||
int lane_idx ///< ID of each thread within a warp
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A_(shared_storage.sharedStorage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.sharedStorage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_B1_(shared_storage.sharedStorage1.operand_B_ref(), thread_idx) {
|
||||
smem_iterator_A_(shared_storage.shared_storage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx) {
|
||||
|
||||
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to three coordinates:
|
||||
@ -241,16 +244,16 @@ public:
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
int gemm_k_iterations_0, ///< number of iterations of the mainloop
|
||||
FragmentC1 &accum, ///< destination accumulator tile
|
||||
IteratorA0 iterator_A, ///< iterator over A operand in global memory
|
||||
IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory
|
||||
IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory
|
||||
FragmentC0 const &src_accum, ///< source accumualtor tile
|
||||
OutputOp output_op_0, ///< epilogue operation after 1st Gemm
|
||||
int gemm_k_iterations_0, ///< number of iterations of the mainloop
|
||||
FragmentC1 &accum, ///< destination accumulator tile
|
||||
IteratorA0 iterator_A, ///< iterator over A operand in global memory
|
||||
IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory
|
||||
IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory
|
||||
FragmentC0 const &src_accum, ///< source accumualtor tile
|
||||
OutputOp output_op_0, ///< epilogue operation after 1st Gemm
|
||||
TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment
|
||||
TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment
|
||||
TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment
|
||||
TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment
|
||||
TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment
|
||||
|
||||
//
|
||||
// Prologue
|
||||
@ -272,8 +275,8 @@ public:
|
||||
++iterator_A;
|
||||
++iterator_B0;
|
||||
|
||||
this->smem_iterator_A_.store(tb_frag_A);
|
||||
this->smem_iterator_B0_.store(tb_frag_B0);
|
||||
this->smem_iterator_A_.store(transform_A0(tb_frag_A));
|
||||
this->smem_iterator_B0_.store(transform_B0(tb_frag_B0));
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B0_;
|
||||
@ -298,23 +301,19 @@ public:
|
||||
int smem_write_stage_idx = 1;
|
||||
|
||||
// Avoid reading out of bounds
|
||||
if (gemm_k_iterations_0 <= 1) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B0.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations_0 <= 1);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 <= 1);
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tightest latency requirement).
|
||||
iterator_A.load(tb_frag_A);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
// Note: The main loop does not support Base::WarpGemmIterations == 2.
|
||||
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {
|
||||
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
@ -328,19 +327,14 @@ public:
|
||||
if (warp_mma_k == Base::kWarpGemmIterations0 - 1) {
|
||||
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_A_.store(tb_frag_A);
|
||||
this->smem_iterator_A_.store(transform_A0(tb_frag_A));
|
||||
|
||||
this->smem_iterator_B0_.store(tb_frag_B0);
|
||||
this->smem_iterator_B0_.store(transform_B0(tb_frag_B0));
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tightest latency requirement).
|
||||
iterator_A.load(tb_frag_A);
|
||||
|
||||
++this->smem_iterator_B0_;
|
||||
++this->smem_iterator_A_;
|
||||
|
||||
++this->smem_iterator_B0_;
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
|
||||
if (smem_write_stage_idx == 1) {
|
||||
@ -369,19 +363,18 @@ public:
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B0.load(tb_frag_B0);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B0;
|
||||
|
||||
// Avoid reading out of bounds if this was the last loop iteration
|
||||
if (gemm_k_iterations_0 <= 2) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B0.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations_0 <= 2);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 <= 2);
|
||||
}
|
||||
|
||||
warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], warp_frag_B0[warp_mma_k % 2], accum0);
|
||||
warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2],
|
||||
warp_frag_B0[warp_mma_k % 2], accum0);
|
||||
}
|
||||
}
|
||||
|
||||
@ -403,7 +396,7 @@ public:
|
||||
|
||||
++iterator_B1;
|
||||
|
||||
this->smem_iterator_B1_.store(tb_frag_B1);
|
||||
this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));
|
||||
|
||||
++this->smem_iterator_B1_;
|
||||
|
||||
@ -413,7 +406,6 @@ public:
|
||||
WarpFragmentA1 warp_frag_A1[2];
|
||||
WarpFragmentB1 warp_frag_B1[2];
|
||||
|
||||
//warp_tile_iterator_A1_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index(0);
|
||||
|
||||
warp_tile_iterator_A1_.load(warp_frag_A1[0], output_op_0);
|
||||
@ -429,9 +421,7 @@ public:
|
||||
int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1;
|
||||
|
||||
// Avoid reading out of bounds
|
||||
if (gemm_k_iterations_1 <= 1) {
|
||||
iterator_B1.clear_mask();
|
||||
}
|
||||
iterator_B1.clear_mask(gemm_k_iterations_1 <= 1);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
@ -454,15 +444,14 @@ public:
|
||||
if (warp_mma_k == Base::kWarpGemmIterations1 - 1) {
|
||||
|
||||
// Write fragments to shared memory
|
||||
|
||||
this->smem_iterator_B1_.store(tb_frag_B1);
|
||||
this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));
|
||||
|
||||
__syncthreads();
|
||||
++smem_iterator_B1_;
|
||||
++this->smem_iterator_B1_;
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
|
||||
if (smem_write_stage_idx == 1) {
|
||||
smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
|
||||
this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
|
||||
}
|
||||
else {
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
@ -475,11 +464,10 @@ public:
|
||||
}
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
|
||||
|
||||
|
||||
warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], output_op_0);
|
||||
this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
|
||||
|
||||
++warp_tile_iterator_A1_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
@ -488,17 +476,14 @@ public:
|
||||
iterator_B1.load(tb_frag_B1);
|
||||
++iterator_B1;
|
||||
|
||||
|
||||
// Avoid reading out of bounds if this was the last loop iteration
|
||||
if (gemm_k_iterations_1 <= 2) {
|
||||
iterator_B1.clear_mask();
|
||||
}
|
||||
iterator_B1.clear_mask(gemm_k_iterations_1 <= 2);
|
||||
}
|
||||
|
||||
warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], warp_frag_B1[warp_mma_k % 2], accum);
|
||||
warp_mma1(accum, warp_frag_A1[warp_mma_k % 2],
|
||||
warp_frag_B1[warp_mma_k % 2], accum);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
};
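A recurring change in this file is replacing the guarded form `if (gemm_k_iterations_0 <= 1) { iterator_A.clear_mask(); }` with the predicated overload `iterator_A.clear_mask(gemm_k_iterations_0 <= 1)`, which clears the access predicates only when its argument is true and avoids a divergent branch. A toy stand-in showing the equivalence (hypothetical iterator, not the CUTLASS type):

struct ToyTileIterator {
  bool mask_enabled = true;
  // Predicated form: disable all further accesses only when 'predicate' is true.
  void clear_mask(bool predicate = true) { mask_enabled = mask_enabled && !predicate; }
};

// The two spellings behave identically:
//   if (gemm_k_iterations <= 1) { it.clear_mask(); }
//   it.clear_mask(gemm_k_iterations <= 1);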
|
||||
|
||||
@ -0,0 +1,541 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped Back-to-back fused GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "threadblock/b2b_mma_base_smem_accumulator.h"
|
||||
#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape0_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorA0_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA0_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorB0_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB0_,
|
||||
/// Iterates over vectors of scale and bias vector in global memory
|
||||
// (concept: VectorIterator)
|
||||
typename IteratorAccumulatorScaleBias_,
|
||||
/// Iterates over accumulator tile
|
||||
typename FragmentIteratorAccumulator_,
|
||||
/// Iterates over accumulator tile in shared memory
|
||||
typename SmemIteratorD0_,
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape1_,
|
||||
/// Iterates over the intermediate accumulator tile in shared memory
|
||||
typename WarpIteratorA1_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorB1_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB1_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Data type of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...)
|
||||
typename OutputOp_,
|
||||
/// Policy describing tuning details (concept: MmaPipelinedPolicy)
|
||||
typename Policy0_,
|
||||
/// Policy describing tuning details (concept: MmaPipelinedPolicy)
|
||||
typename Policy1_,
|
||||
/// Transformation applied to A0 operand
|
||||
typename TransformA0_ = NumericArrayConverter<
|
||||
typename SmemIteratorA0_::Element,
|
||||
typename IteratorA0_::Element,
|
||||
IteratorA0_::Fragment::kElements>,
|
||||
///
|
||||
/// Transformation applied to B0 operand
|
||||
typename TransformB0_ = NumericArrayConverter<
|
||||
typename SmemIteratorB0_::Element,
|
||||
typename IteratorB0_::Element,
|
||||
IteratorB0_::Fragment::kElements>,
|
||||
///
|
||||
/// Transformation applied to B1 operand
|
||||
typename TransformB1_ = NumericArrayConverter<
|
||||
typename SmemIteratorB1_::Element,
|
||||
typename IteratorB1_::Element,
|
||||
IteratorB1_::Fragment::kElements>,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool
|
||||
>
|
||||
class B2bMmaPipelinedSmemAccumulator :
|
||||
public B2bMmaBaseSmemAccumulator<Shape0_, Shape1_, Policy0_, Policy1_, SmemIteratorD0_, 2> {
|
||||
public:
|
||||
|
||||
///< Base class
|
||||
using Base = B2bMmaBaseSmemAccumulator<Shape0_, Shape1_, Policy0_, Policy1_, SmemIteratorD0_, 2>;
|
||||
|
||||
using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory
|
||||
using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory
|
||||
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory
|
||||
using Policy0 = Policy0_; ///< Policy0 describing tuning details
|
||||
|
||||
using SmemIteratorA0 = SmemIteratorA0_;
|
||||
using SmemIteratorB0 = SmemIteratorB0_;
|
||||
using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory
|
||||
|
||||
using FragmentIteratorAccumulator = FragmentIteratorAccumulator_; ///< Iterates over accumulator tile
|
||||
|
||||
using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory
|
||||
using Policy1 = Policy1_; ///< Policy1 describing tuning details
|
||||
|
||||
using SmemIteratorB1 = SmemIteratorB1_;
|
||||
using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory
|
||||
|
||||
|
||||
using ElementC = ElementC_; ///< Data type of accumulator matrix
|
||||
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
|
||||
|
||||
using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm
|
||||
|
||||
using TransformA0 = TransformA0_;
|
||||
using TransformB0 = TransformB0_;
|
||||
using TransformB1 = TransformB1_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of operand A loaded from global memory
|
||||
using FragmentA0 = typename IteratorA0::Fragment;
|
||||
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB0 = typename IteratorB0::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC0 = typename Policy0::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator0 = typename Policy0::Operator;
|
||||
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB1 = typename IteratorB1::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC1 = typename Policy1::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator1 = typename Policy1::Operator;
|
||||
|
||||
/// Obtain the arch tag from the warp-level operator
|
||||
using ArchTag = typename Policy0::Operator::ArchTag;
|
||||
|
||||
/// Complex transform on A0 operand
|
||||
static ComplexTransform const kTransformA0 = Operator0::kTransformA;
|
||||
|
||||
/// Complex transform on B0 operand
|
||||
static ComplexTransform const kTransformB0 = Operator0::kTransformB;
|
||||
|
||||
/// Complex transform on B1 operand
|
||||
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
|
||||
|
||||
/// statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
|
||||
static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");
|
||||
|
||||
/// Epilogue in shared memory
|
||||
using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator<
|
||||
SmemIteratorD0, ///< SmemTileIterator
|
||||
FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator
|
||||
IteratorAccumulatorScaleBias, ///< ScaleBiasIterator
|
||||
OutputOp>; ///< Output operator
|
||||
|
||||
|
||||
|
||||
private:
|
||||
|
||||
using WarpFragmentA0 = typename Operator0::FragmentA;
|
||||
using WarpFragmentB0 = typename Operator0::FragmentB;
|
||||
using WarpFragmentA1 = typename Operator1::FragmentA;
|
||||
using WarpFragmentB1 = typename Operator1::FragmentB;
|
||||
|
||||
protected:
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA0 smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B0 operand to shared memory
|
||||
SmemIteratorB0 smem_iterator_B0_;
|
||||
|
||||
/// Shared Memory Iterator to store accumulator tile
|
||||
SmemIteratorD0 smem_iterator_D0_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
|
||||
WarpIteratorA1 warp_tile_iterator_A1_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B1 operand to shared memory
|
||||
SmemIteratorB1 smem_iterator_B1_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
B2bMmaPipelinedSmemAccumulator(
|
||||
typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
int thread_idx, ///< ID within the threadblock
|
||||
int warp_idx, ///< ID of warp
|
||||
int lane_idx ///< ID of each thread within a warp
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
|
||||
warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx),
|
||||
smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) {
|
||||
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
|
||||
int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM;
|
||||
int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM;
|
||||
|
||||
int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k_0;
|
||||
|
||||
int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN);
|
||||
int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN);
|
||||
|
||||
int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM;
|
||||
int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM;
|
||||
|
||||
int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k_1;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m_0, tile_offset_k_0});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n_0});
|
||||
warp_tile_iterator_A1_.add_tile_offset({warp_idx_m_1, tile_offset_k_1});
|
||||
this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n_1});
|
||||
|
||||
// Add smem accumulator iterator warp offset
|
||||
smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow,
|
||||
warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn});
|
||||
|
||||
}
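// Note (added for clarity, not in the original source): the decomposition above maps the linear
// warp_idx to tile coordinates as
//   warp_idx_mn = warp_idx % (WarpCount::kM * WarpCount::kN)   // position within the M x N warp grid
//   warp_idx_k  = warp_idx / (WarpCount::kM * WarpCount::kN)   // K partition
//   warp_idx_m  = warp_idx_mn % WarpCount::kM
//   warp_idx_n  = warp_idx_mn / WarpCount::kM
// For example, with WarpCount0 = <2, 2, 1>, warp_idx 3 lands at (m, n, k) = (1, 1, 0).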
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
int gemm_k_iterations_0, ///< number of iterations of the mainloop
|
||||
FragmentC1 &accum, ///< destination accumulator tile
|
||||
IteratorA0 iterator_A, ///< iterator over A operand in global memory
|
||||
IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory
|
||||
IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory
|
||||
FragmentC0 const &src_accum, ///< source accumulator tile
|
||||
OutputOp output_op_0, ///< epilogue operation after 1st Gemm
|
||||
TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment
|
||||
TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment
|
||||
TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
FragmentC0 accum0 = src_accum;
|
||||
|
||||
FragmentA0 tb_frag_A;
|
||||
FragmentB0 tb_frag_B0;
|
||||
|
||||
tb_frag_A.clear();
|
||||
tb_frag_B0.clear();
|
||||
|
||||
// The last kblock is loaded in the prolog
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B0.load(tb_frag_B0);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B0;
|
||||
|
||||
this->smem_iterator_A_.store(transform_A0(tb_frag_A));
|
||||
this->smem_iterator_B0_.store(transform_B0(tb_frag_B0));
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B0_;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math instructions
|
||||
WarpFragmentA0 warp_frag_A0[2];
|
||||
WarpFragmentB0 warp_frag_B0[2];
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_frag_A0[0]);
|
||||
this->warp_tile_iterator_B0_.load(warp_frag_B0[0]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
Operator0 warp_mma0;
|
||||
|
||||
int smem_write_stage_idx = 1;
|
||||
|
||||
// Avoid reading out of bounds
|
||||
iterator_A.clear_mask(gemm_k_iterations_0 <= 1);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 <= 1);
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tightest latency requirement).
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
|
||||
// as the case may be.
|
||||
|
||||
if (warp_mma_k == Base::kWarpGemmIterations0 - 1) {
|
||||
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_A_.store(transform_A0(tb_frag_A));
|
||||
|
||||
this->smem_iterator_B0_.store(transform_B0(tb_frag_B0));
|
||||
|
||||
__syncthreads();
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B0_;
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
|
||||
if (smem_write_stage_idx == 1) {
|
||||
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
|
||||
}
|
||||
else {
|
||||
this->warp_tile_iterator_A0_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset(
|
||||
{-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0,
|
||||
0});
|
||||
}
|
||||
|
||||
smem_write_stage_idx ^= 1;
|
||||
}
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B0.load(tb_frag_B0);
|
||||
++iterator_A;
|
||||
++iterator_B0;
|
||||
|
||||
// Avoid reading out of bounds if this was the last loop iteration
|
||||
iterator_A.clear_mask(gemm_k_iterations_0 <= 2);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 <= 2);
|
||||
}
|
||||
|
||||
warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2],
|
||||
warp_frag_B0[warp_mma_k % 2], accum0);
|
||||
}
|
||||
}
|
||||
|
||||
/// Epilogue for the first Implicit Gemm
|
||||
Epilogue0 epilogue0;
|
||||
|
||||
epilogue0(output_op_0, smem_iterator_D0_, accum0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
//2nd Gemm
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
FragmentB1 tb_frag_B1;
|
||||
|
||||
tb_frag_B1.clear();
|
||||
|
||||
// The last kblock is loaded in the prolog
|
||||
iterator_B1.load(tb_frag_B1);
|
||||
|
||||
++iterator_B1;
|
||||
|
||||
this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));
|
||||
|
||||
++this->smem_iterator_B1_;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math instructions
|
||||
WarpFragmentA1 warp_frag_A1[2];
|
||||
WarpFragmentB1 warp_frag_B1[2];
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index(0);
|
||||
|
||||
warp_tile_iterator_A1_.load(warp_frag_A1[0]);
|
||||
this->warp_tile_iterator_B1_.load(warp_frag_B1[0]);
|
||||
|
||||
++warp_tile_iterator_A1_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
Operator1 warp_mma1;
|
||||
|
||||
smem_write_stage_idx = 1;
|
||||
|
||||
int gemm_k_iterations_1 = Shape0::kN / Shape1::kK;
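// Clarifying note (not in the original source): in a back-to-back GEMM the K extent of the second
// GEMM equals the N extent of the first, because operand A1 is the first GEMM's accumulator tile
// staged in shared memory; hence the threadblock iterates Shape0::kN / Shape1::kK times here.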
|
||||
|
||||
// Avoid reading out of bounds
|
||||
iterator_B1.clear_mask(gemm_k_iterations_1 <= 1);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) {
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
|
||||
// as the case may be.
|
||||
|
||||
if (warp_mma_k == Base::kWarpGemmIterations1 - 1) {
|
||||
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));
|
||||
|
||||
__syncthreads();
|
||||
|
||||
++this->smem_iterator_B1_;
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
|
||||
if (smem_write_stage_idx == 1) {
|
||||
this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
|
||||
}
|
||||
else {
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{-Base::kStages * Policy1::kPartitionsK *
|
||||
Base::kWarpGemmIterations1,
|
||||
0});
|
||||
}
|
||||
|
||||
smem_write_stage_idx ^= 1;
|
||||
|
||||
}
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
|
||||
|
||||
// skip warp tile loading for the last kgroup
|
||||
if(gemm_k_iterations_1 > 1 || warp_mma_k < Base::kWarpGemmIterations1 - 1)
|
||||
warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++warp_tile_iterator_A1_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
|
||||
iterator_B1.load(tb_frag_B1);
|
||||
|
||||
++iterator_B1;
|
||||
|
||||
// Avoid reading out of bounds if this was the last loop iteration
|
||||
iterator_B1.clear_mask(gemm_k_iterations_1 <= 2);
|
||||
}
|
||||
|
||||
warp_mma1(accum, warp_frag_A1[warp_mma_k % 2],
|
||||
warp_frag_B1[warp_mma_k % 2], accum);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
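In summary, B2bMmaPipelinedSmemAccumulator differs from B2bMmaPipelined only in how the first GEMM's result reaches the second: rather than iterating over register-resident accumulators with a fragment iterator, it runs the shared-memory epilogue Epilogue0 after the first mainloop and then reads the staged tile back through warp_tile_iterator_A1_. Condensed from the operator() above (no additional API, just an outline):

// 1st mainloop:  accum0 += A0 * B0                       // accumulators stay in registers
// epilogue0(output_op_0, smem_iterator_D0_, accum0);     // apply output_op_0, store D0 to shared memory
// __syncthreads();
// 2nd mainloop:  accum  += D0 * B1                       // A1 = D0, read back via warp_tile_iterator_A1_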
|
||||
510
examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h
Normal file
@ -0,0 +1,510 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "threadblock/b2b_mma_pipelined.h"
|
||||
#include "threadblock/b2b_mma_multistage.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA_,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA_,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB_,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB_,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator_,
|
||||
/// Layout type for C and D matrix operands
|
||||
typename LayoutC_,
|
||||
/// Operator class tag
|
||||
typename OperatorClass_,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag_,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0_,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1_,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0_,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1_,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape_,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
typename Operator,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
bool AccumulatorsInRowMajor = false,
|
||||
/// Staging the accumulators in shared memory.
|
||||
bool SmemAccumulator = false>
|
||||
struct DefaultB2bMma;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output with 2-stage pipeline
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp>
|
||||
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
kAlignmentB, ElementAccumulator, layout::RowMajor,
|
||||
arch::OpClassTensorOp, ArchTag,
|
||||
ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1,
|
||||
InstructionShape, 2, Operator, EpilogueOutputOp, false> {
|
||||
// Define the MmaCore components
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator, layout::RowMajor,
|
||||
arch::OpClassTensorOp, 2, Operator>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator, layout::RowMajor,
|
||||
arch::OpClassTensorOp, 2, Operator>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA0 =
|
||||
cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore0::Shape::kM, MmaCore0::Shape::kK>,
|
||||
ElementA, LayoutA, 1, typename MmaCore0::IteratorThreadMapA, kAlignmentA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB0 =
|
||||
cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore0::Shape::kK, MmaCore0::Shape::kN>,
|
||||
ElementB, LayoutB, 0, typename MmaCore0::IteratorThreadMapB, kAlignmentB>;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
using AccumulatorLayout = cutlass::layout::ColumnMajor;
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB1 =
|
||||
cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore1::Shape::kK, MmaCore1::Shape::kN>,
|
||||
ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB, kAlignmentB>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined<
|
||||
typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
|
||||
IteratorB0, typename MmaCore0::SmemIteratorB,
|
||||
typename MmaCore1::Shape, FragmentIteratorA1,
|
||||
IteratorB1, typename MmaCore1::SmemIteratorB,
|
||||
ElementAccumulator, layout::RowMajor,
|
||||
EpilogueOutputOp,
|
||||
typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>;
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output for multi-stage
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int Stages,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp>
|
||||
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
kAlignmentB, ElementAccumulator, layout::RowMajor,
|
||||
arch::OpClassTensorOp, ArchTag,
|
||||
ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1,
|
||||
InstructionShape, Stages, Operator, EpilogueOutputOp, false> {
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA =
|
||||
((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
|
||||
// Define the MmaCore components
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, Operator, false, CacheOpA, CacheOpB>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, Operator, false, CacheOpA, CacheOpB>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using AccessTypeA0 = cutlass::Array<ElementA, kAlignmentA>;
|
||||
using IteratorA0 =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA0>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using AccessTypeB0 = cutlass::Array<ElementB, kAlignmentB>;
|
||||
using IteratorB0 =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB0>;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
using AccumulatorLayout = cutlass::layout::ColumnMajor;
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using AccessTypeB1 = cutlass::Array<ElementB, kAlignmentB>;
|
||||
using IteratorB1 =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB1>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistage<
|
||||
typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
|
||||
MmaCore0::kCacheOpA,
|
||||
IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB,
|
||||
typename MmaCore1::Shape, FragmentIteratorA1,
|
||||
IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB,
|
||||
ElementAccumulator, layout::RowMajor,
|
||||
EpilogueOutputOp,
|
||||
typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>;
|
||||
|
||||
};
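The multistage specialization above chooses the cp.async cache operation purely from the access width: a full 128-bit access per thread (element bits x alignment == 128) selects CacheOperation::Global, anything narrower falls back to CacheOperation::Always. A small constexpr restatement of that rule (illustrative; it mirrors the CacheOpA/CacheOpB definitions above rather than adding anything new):

enum class CacheOperation { Always, Global };

// Cache operation the specialization would pick for an operand with
// 'element_bits' per element and 'alignment' elements per access.
constexpr CacheOperation select_cache_op(int element_bits, int alignment) {
  return (element_bits * alignment == 128) ? CacheOperation::Global
                                           : CacheOperation::Always;
}

static_assert(select_cache_op(16, 8) == CacheOperation::Global, "half_t, 8-element access");
static_assert(select_cache_op(16, 4) == CacheOperation::Always, "half_t, 4-element access");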
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for column-major-interleaved output with 2-stage pipeline
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename OperatorClass,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
/// Number of Interleaved K
|
||||
int InterleavedK>
|
||||
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
kAlignmentB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, 2, Operator, EpilogueOutputOp, true> {
|
||||
// Define the MmaCore components
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, 2, Operator,
|
||||
true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, 2, Operator,
|
||||
true>;
|
||||
|
||||
static_assert(kAlignmentA == 128 / sizeof_bits<ElementA>::value,
|
||||
"Alignment must match thread data map's vector length");
|
||||
|
||||
static_assert(kAlignmentB == 128 / sizeof_bits<ElementB>::value,
|
||||
"Alignment must match thread data map's vector length");
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA0 = cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore0::Shape::kM, MmaCore0::Shape::kK>, ElementA,
|
||||
LayoutA, 1, typename MmaCore0::IteratorThreadMapA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB0 = cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore0::Shape::kK, MmaCore0::Shape::kN>, ElementB,
|
||||
LayoutB, 0, typename MmaCore0::IteratorThreadMapB>;
|
||||
|
||||
// Use fragment iterator for A1 operand
|
||||
using AccumulatorLayout = cutlass::layout::RowMajor; //AccumulatorsInRowMajor = true
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout,
|
||||
InstructionShape, EpilogueOutputOp>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB1 =
|
||||
cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore1::Shape::kK, MmaCore1::Shape::kN>,
|
||||
ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>;
|
||||
|
||||
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined<
|
||||
typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
|
||||
IteratorB0, typename MmaCore0::SmemIteratorB,
|
||||
typename MmaCore1::Shape, FragmentIteratorA1,
|
||||
IteratorB1, typename MmaCore1::SmemIteratorB,
|
||||
ElementAccumulator, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
EpilogueOutputOp,
|
||||
typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for column-major-interleaved output with multi-stage
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int Stages,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
/// Number of Interleaved K
|
||||
int InterleavedK>
|
||||
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
kAlignmentB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, ArchTag,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, Stages, Operator, EpilogueOutputOp, true> {
|
||||
// Define the MmaCore components
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, Stages,
|
||||
Operator, true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, Stages,
|
||||
Operator, true>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
|
||||
using IteratorA0 =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
|
||||
using IteratorB0 =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB>;
|
||||
|
||||
// Use fragment iterator for A1 operand
|
||||
using AccumulatorLayout = cutlass::layout::RowMajor; //AccumulatorsInRowMajor = true
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout,
|
||||
InstructionShape, EpilogueOutputOp>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB>;
|
||||
|
||||
|
||||
|
||||
// Define the threadblock-scoped multistage matrix multiply
|
||||
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistage<
|
||||
typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
|
||||
MmaCore0::kCacheOpA,
|
||||
IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB,
|
||||
typename MmaCore1::Shape, FragmentIteratorA1,
|
||||
IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB,
|
||||
ElementAccumulator, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
EpilogueOutputOp,
|
||||
typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,605 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "threadblock/b2b_mma_pipelined_smem_accumulator.h"
|
||||
#include "threadblock/b2b_mma_multistage_smem_accumulator.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output with 2-stage pipeline
/// Accumulator will be staged in shared memory.
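// Dataflow summary of the fused back-to-back (B2B) MMA defined below: the first GEMM's
// accumulator is transformed by the elementwise EpilogueOutputOp (e.g. per-channel scale/bias
// supplied through IteratorAccumulatorScaleBias), staged to shared memory via SmemIteratorD0,
// and then re-read as the A operand of the second GEMM through WarpIteratorA1:
//
//   D0 = EpilogueOutputOp(A0 * B0, scale/bias)   // written to the shared-memory staging buffer
//   D1 = D0 * B1                                 // accumulated by the second warp-level MMA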
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp>
|
||||
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
kAlignmentB, ElementAccumulator, layout::RowMajor,
|
||||
arch::OpClassTensorOp, ArchTag,
|
||||
ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1,
|
||||
InstructionShape, 2, Operator, EpilogueOutputOp, false, true> {
|
||||
// Define the MmaCore components
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator, layout::RowMajor,
|
||||
arch::OpClassTensorOp, 2, Operator>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator, layout::RowMajor,
|
||||
arch::OpClassTensorOp, 2, Operator>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA0 =
|
||||
cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore0::Shape::kM, MmaCore0::Shape::kK>,
|
||||
ElementA, LayoutA, 1, typename MmaCore0::IteratorThreadMapA, kAlignmentA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB0 =
|
||||
cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore0::Shape::kK, MmaCore0::Shape::kN>,
|
||||
ElementB, LayoutB, 0, typename MmaCore0::IteratorThreadMapB, kAlignmentB>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB1 =
|
||||
cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore1::Shape::kK, MmaCore1::Shape::kN>,
|
||||
ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB, kAlignmentB>;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 2;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
typename EpilogueOutputOp::ElementOutput,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator<
|
||||
typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
|
||||
IteratorB0, typename MmaCore0::SmemIteratorB,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator, SmemIteratorD0,
|
||||
typename MmaCore1::Shape, WarpIteratorA1,
|
||||
IteratorB1, typename MmaCore1::SmemIteratorB,
|
||||
ElementAccumulator, layout::RowMajor,
|
||||
EpilogueOutputOp,
|
||||
typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>;
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output for multi-stage
|
||||
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int Stages,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp>
|
||||
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
kAlignmentB, ElementAccumulator, layout::RowMajor,
|
||||
arch::OpClassTensorOp, ArchTag,
|
||||
ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1,
|
||||
InstructionShape, Stages, Operator, EpilogueOutputOp, false, true> {
|
||||
|
||||
  static cutlass::arch::CacheOperation::Kind const CacheOpA =
      ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
          ? cutlass::arch::CacheOperation::Global
          : cutlass::arch::CacheOperation::Always;

  static cutlass::arch::CacheOperation::Kind const CacheOpB =
      ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
          ? cutlass::arch::CacheOperation::Global
          : cutlass::arch::CacheOperation::Always;
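  // For example (an illustrative note, assuming 16-bit operands): with ElementA = half_t and
  // kAlignmentA = 8, sizeof_bits<ElementA>::value * kAlignmentA == 16 * 8 == 128, so the 16-byte
  // cp.async form with global (L2-only) caching is selected; narrower accesses fall back to
  // CacheOperation::Always, which caches at all levels.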
|
||||
|
||||
|
||||
// Define the MmaCore components
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, Operator, false, CacheOpA, CacheOpB>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, Operator, false, CacheOpA, CacheOpB>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using AccessTypeA0 = cutlass::Array<ElementA, kAlignmentA>;
|
||||
using IteratorA0 =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA0>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using AccessTypeB0 = cutlass::Array<ElementB, kAlignmentB>;
|
||||
using IteratorB0 =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
|
||||
ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB0>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using AccessTypeB1 = cutlass::Array<ElementB, kAlignmentB>;
|
||||
using IteratorB1 =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB1>;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 2;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
typename EpilogueOutputOp::ElementOutput,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistageSmemAccumulator<
|
||||
typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
|
||||
MmaCore0::kCacheOpA,
|
||||
IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator, SmemIteratorD0,
|
||||
typename MmaCore1::Shape, WarpIteratorA1,
|
||||
IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB,
|
||||
ElementAccumulator, layout::RowMajor,
|
||||
EpilogueOutputOp,
|
||||
typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>;
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for column-major-interleaved output with 2-stage pipeline
|
||||
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename OperatorClass,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
/// Number of Interleaved K
|
||||
int InterleavedK>
|
||||
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
kAlignmentB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, 2, Operator, EpilogueOutputOp, true, true> {
|
||||
// Define the MmaCore components
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, 2, Operator,
|
||||
true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, 2, Operator,
|
||||
true>;
|
||||
|
||||
  static_assert(kAlignmentA == 128 / sizeof_bits<ElementA>::value,
                "Alignment must match thread data map's vector length");

  static_assert(kAlignmentB == 128 / sizeof_bits<ElementB>::value,
                "Alignment must match thread data map's vector length");
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA0 = cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore0::Shape::kM, MmaCore0::Shape::kK>, ElementA,
|
||||
LayoutA, 1, typename MmaCore0::IteratorThreadMapA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB0 = cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore0::Shape::kK, MmaCore0::Shape::kN>, ElementB,
|
||||
LayoutB, 0, typename MmaCore0::IteratorThreadMapB>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB1 =
|
||||
cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore1::Shape::kK, MmaCore1::Shape::kN>,
|
||||
ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 4; //For interleaved layout
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
typename EpilogueOutputOp::ElementOutput,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator<
|
||||
typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
|
||||
IteratorB0, typename MmaCore0::SmemIteratorB,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator, SmemIteratorD0,
|
||||
typename MmaCore1::Shape, WarpIteratorA1,
|
||||
IteratorB1, typename MmaCore1::SmemIteratorB,
|
||||
ElementAccumulator, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
EpilogueOutputOp,
|
||||
typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>;
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for column-major-interleaved output with multi-stage
|
||||
/// Accumulator will be staged in shared memory.
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int Stages,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
/// Number of Interleaved K
|
||||
int InterleavedK>
|
||||
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
kAlignmentB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, ArchTag,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, Stages, Operator, EpilogueOutputOp, true, true> {
|
||||
// Define the MmaCore components
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, Stages,
|
||||
Operator, true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, Stages,
|
||||
Operator, true>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
|
||||
using IteratorA0 =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
|
||||
ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
|
||||
using IteratorB0 =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
|
||||
using IteratorB1 =
|
||||
cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
|
||||
ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB>;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
|
||||
using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
|
||||
using MmaPolicy0 = typename MmaCore0::MmaPolicy;
|
||||
using MmaPolicy1 = typename MmaCore1::MmaPolicy;
|
||||
|
||||
// Use fragment iterator for the accumulator
|
||||
using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
|
||||
using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
WarpShape0, InstructionShape,
|
||||
ElementAccumulator,
|
||||
typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
/// Define iterators over tiles from scale/bias vectors
|
||||
using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
|
||||
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
|
||||
static int const kElementsPerAccess = 4;
|
||||
using IteratorAccumulatorScaleBias =
|
||||
cutlass::transform::threadblock::VectorIterator<
|
||||
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
|
||||
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kK>,
|
||||
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
|
||||
>;
|
||||
|
||||
// Store Accumulator tiles to Shared Memory
|
||||
using SmemIteratorD0 =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
typename EpilogueOutputOp::ElementOutput,
|
||||
SmemAccumulatorLayout
|
||||
>;
|
||||
|
||||
static int const kThreadCount = 32;
|
||||
// load warp tile from Shared Memory accumulator
|
||||
using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
|
||||
MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
|
||||
ElementA, SmemAccumulatorLayout,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
|
||||
|
||||
|
||||
// Define the threadblock-scoped multistage matrix multiply
|
||||
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistageSmemAccumulator<
|
||||
typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
|
||||
MmaCore0::kCacheOpA,
|
||||
IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB,
|
||||
IteratorAccumulatorScaleBias,
|
||||
FragmentIteratorAccumulator, SmemIteratorD0,
|
||||
typename MmaCore1::Shape, WarpIteratorA1,
|
||||
IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB,
|
||||
ElementAccumulator, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
EpilogueOutputOp,
|
||||
typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
35
examples/14_ampere_tf32_tensorop_gemm/CMakeLists.txt
Normal file
@ -0,0 +1,35 @@
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
14_ampere_tf32_tensorop_gemm
|
||||
ampere_tf32_tensorop_gemm.cu
|
||||
)
|
||||
|
||||
@ -0,0 +1,472 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/**
Please check examples 07 and 08 for the basics of tensor op GEMM kernels. On the NVIDIA Ampere
architecture, most of the concepts still hold. The two main differences are:

1. The NVIDIA Ampere architecture introduces a new series of Tensor Core instructions (see
   include/cutlass/arch/mma_sm80.h) which are more efficient on Ampere.

2. The NVIDIA Ampere architecture uses cp_async() to build a multistage software pipeline that
   better hides latency (see include/cutlass/gemm/threadblock/mma_multistage.h).

Moreover, the NVIDIA Ampere architecture adds support for the tfloat32 data type (see
include/cutlass/tfloat32.h) in Tensor Cores. One big advantage is that fp32 data can be loaded and
implicitly converted to tf32 inside the GEMM kernel, which means no code change is needed to
accelerate traditional fp32 workloads on the NVIDIA Ampere architecture.
*/

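// A minimal sketch (not part of the original example) of the implicit fp32 -> tf32 conversion
// described above: tf32 keeps fp32's 8-bit exponent but only 10 explicit mantissa bits, so the
// conversion rounds away the 13 low mantissa bits.
#include "cutlass/tfloat32.h"

inline float round_through_tf32_sketch(float x) {
  cutlass::tfloat32_t t(x);  // fp32 -> tf32: low mantissa bits are rounded away
  return float(t);           // back to fp32; only the low mantissa bits can differ from x
}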
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Result structure
|
||||
struct Result {
|
||||
|
||||
double runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
Result(
|
||||
double runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess
|
||||
):
|
||||
runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { }
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
int batch_count;
|
||||
float alpha;
|
||||
float beta;
|
||||
|
||||
bool reference_check;
|
||||
int iterations;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
problem_size({5120, 4096, 4096}),
|
||||
batch_count(1),
|
||||
reference_check(true),
|
||||
iterations(20),
|
||||
alpha(1),
|
||||
beta() { }
|
||||
|
||||
bool valid() {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", problem_size.m());
|
||||
cmd.get_cmd_line_argument("n", problem_size.n());
|
||||
cmd.get_cmd_line_argument("k", problem_size.k());
|
||||
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "14_ampere_tf32_tensorop_gemm example\n\n"
|
||||
<< " This example uses the CUTLASS Library to execute TF32 tensorop GEMM computations.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --m=<int> GEMM M dimension\n"
|
||||
<< " --n=<int> GEMM N dimension\n"
|
||||
<< " --k=<int> GEMM K dimension\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ ./examples/14_ampere_tf32_tensorop_gemm/14_ampere_tf32_tensorop_gemm --m=1024 --n=512 --k=1024 \\\n"
|
||||
<< " --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
  /// Compute performance in GFLOP/s
  double gflops(double runtime_s) const {

    // Number of real-valued multiply-adds
    int64_t fmas = problem_size.product() * batch_count;

    // Two flops per multiply-add
    return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
  }
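  // Example: with the default 5120 x 4096 x 4096 problem and batch_count = 1, fmas is
  // 5120 * 4096 * 4096 ~= 8.59e10, so a 10 ms run corresponds to
  // 2 * 8.59e10 / 1e9 / 0.01 ~= 17,180 GFLOP/s.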
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// The code section below describes the data types for the input and output matrices and for the
// computation between elements of the input matrices.
using ElementAccumulator = float;                   // <- data type of accumulator
using ElementComputeEpilogue = ElementAccumulator;  // <- data type of epilogue operations
using ElementInputA = float;                        // <- data type of elements in input matrix A
using ElementInputB = float;                        // <- data type of elements in input matrix B
using ElementOutput = float;                        // <- data type of elements in output matrix D

// The code section below describes the matrix layouts of the input and output matrices: Row Major
// for matrix A, Column Major for matrix B, and Row Major for matrices C and D.
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
|
||||
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = cutlass::arch::OpClassTensorOp;

// This code section describes the CUDA SM architecture number
using SmArch = cutlass::arch::Sm80;

// This code section describes the tile size a thread block will compute
using ShapeMMAThreadBlock =
    cutlass::gemm::GemmShape<128, 128, 16>;  // <- threadblock tile M = 128, N = 128, K = 16
// This code section describes the tile size a warp will compute
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 16>;  // <- warp tile M = 64, N = 64, K = 16
// This code section describes the size of the MMA op
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>;  // <- MMA Op tile M = 16, N = 8, K = 8
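// With these shapes, each threadblock carries (128 / 64) x (128 / 64) = 4 warps, and each warp
// covers its 64x64x16 tile with 16x8x8 Tensor Core MMA instructions.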
|
||||
|
||||
// This code section describes how threadblocks are scheduled on the GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;  // <- default identity swizzle
|
||||
|
||||
// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,                                     // <- data type of output matrix
    128 / cutlass::sizeof_bits<ElementOutput>::value,  // <- the number of elements per vectorized
                                                       //    memory access. For a byte, it's 16
                                                       //    elements. This becomes the vector width
                                                       //    of math instructions in the epilogue too
    ElementAccumulator,                                // <- data type of accumulator
    ElementComputeEpilogue>;  // <- data type for alpha/beta in linear combination function
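// The linear combination epilogue computes D = alpha * accumulator + beta * C elementwise; with
// float output, 128 / 32 = 4 elements are read or written per vectorized memory access.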
|
||||
|
||||
// Number of pipelines you want to use
|
||||
constexpr int NumStages = 4;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementAccumulator,
|
||||
MMAOp,
|
||||
SmArch,
|
||||
ShapeMMAThreadBlock,
|
||||
ShapeMMAWarp,
|
||||
ShapeMMAOp,
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages>;
|
||||
|
||||
int run(Options &options) {
|
||||
|
||||
// Create a tuple of problem size for matrix multiplication
|
||||
cutlass::gemm::GemmCoord problem_size = options.problem_size;
|
||||
|
||||
// Initialize tensors using CUTLASS helper functions
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
|
||||
problem_size.mk()); // <- Create matrix A with dimensions M x K
|
||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
|
||||
problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
|
||||
problem_size.mn()); // <- Create matrix C with dimensions M x N
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
|
||||
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
|
||||
// CUTLASS kernel
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(
|
||||
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
|
||||
// reference kernel
|
||||
|
||||
// Fill input and output matrices on host using CUTLASS helper functions
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a.host_view(),
|
||||
1,
|
||||
ElementInputA(4),
|
||||
ElementInputA(-4),
|
||||
0); // <- Fill matrix A on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b.host_view(),
|
||||
1,
|
||||
ElementInputB(4),
|
||||
ElementInputB(-4),
|
||||
0); // <- Fill matrix B on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_c.host_view(),
|
||||
1,
|
||||
ElementOutput(4),
|
||||
ElementOutput(-4),
|
||||
0); // <- Fill matrix C on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_d.host_view()); // <- fill matrix D on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a.sync_device();
|
||||
tensor_b.sync_device();
|
||||
tensor_c.sync_device();
|
||||
tensor_d.sync_device();
|
||||
tensor_ref_d.sync_device();
|
||||
|
||||
// Initialize alpha and beta for dot product computation
|
||||
ElementComputeEpilogue alpha = ElementComputeEpilogue(options.alpha);
|
||||
ElementComputeEpilogue beta = ElementComputeEpilogue(options.beta);
|
||||
|
||||
// Split K dimension into 1 partitions
|
||||
int split_k_slices = 1;
|
||||
|
||||
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
|
||||
// instantiated CUTLASS kernel
|
||||
typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication
|
||||
tensor_a.device_ref(), // <- reference to matrix A on device
|
||||
tensor_b.device_ref(), // <- reference to matrix B on device
|
||||
tensor_c.device_ref(), // <- reference to matrix C on device
|
||||
tensor_d.device_ref(), // <- reference to matrix D on device
|
||||
{alpha, beta}, // <- tuple of alpha and beta
|
||||
split_k_slices}; // <- k-dimension split factor
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm_op;
|
||||
|
||||
  // Check whether the problem size is supported
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
status = gemm_op.initialize(arguments, workspace.get());
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// Result structure
|
||||
Result result;
|
||||
|
||||
//
|
||||
// Construct events
|
||||
//
|
||||
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (auto & event : events) {
|
||||
result.error = cudaEventCreate(&event);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// Record an event at the start of a series of GEMMs
|
||||
result.error = cudaEventRecord(events[0]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
//
|
||||
// Run profiling loop
|
||||
//
|
||||
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
// Launch initialized CUTLASS kernel
|
||||
status = gemm_op();
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
//
|
||||
// Stop profiling loop
|
||||
//
|
||||
|
||||
// Record an event when the GEMMs are complete
|
||||
result.error = cudaEventRecord(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Wait for work on the device to complete.
|
||||
result.error = cudaEventSynchronize(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
result.runtime_ms = double(runtime_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.runtime_ms / 1000.0);
|
||||
|
||||
// Cleanup
|
||||
for (auto event : events) {
|
||||
(void)cudaEventDestroy(event);
|
||||
}
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
cutlass::reference::device::Gemm<ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementComputeEpilogue,
|
||||
ElementComputeEpilogue>
|
||||
gemm_device;
|
||||
|
||||
// Launch device reference gemm kernel
|
||||
gemm_device(problem_size,
|
||||
alpha,
|
||||
tensor_a.device_ref(),
|
||||
tensor_b.device_ref(),
|
||||
beta,
|
||||
tensor_c.device_ref(),
|
||||
tensor_ref_d.device_ref());
|
||||
|
||||
// Wait for kernels to finish
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Copy output data from CUTLASS and reference kernel to host for comparison
|
||||
tensor_d.sync_host();
|
||||
tensor_ref_d.sync_host();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
tensor_d.host_view(),
|
||||
tensor_ref_d.host_view());
|
||||
|
||||
if (passed) {
|
||||
std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPs: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
std::cout << (passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
return (passed ? 0 : -1);
|
||||
}
|
||||
|
||||
int main(int argc, const char **argv) {
|
||||
|
||||
bool notSupported = false;
|
||||
|
||||
// Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available
|
||||
// in CUDA 11.0.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ >= 11)) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!((props.major * 10 + props.minor) >= 80)) {
|
||||
std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
|
||||
<< std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
    // Return zero so this test passes on older toolkits; its action is a no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
Options options;
|
||||
options.parse(argc, argv);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
printf("%d x %d x %d TF32 tensor op Matrix Multiply\n", \
|
||||
options.problem_size.m(), options.problem_size.n(), options.problem_size.k());
|
||||
|
||||
if (!options.valid()) {
|
||||
std::cerr << "Invalid problem." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
return run(options);
|
||||
}
|
||||
35
examples/15_ampere_sparse_tensorop_gemm/CMakeLists.txt
Normal file
@ -0,0 +1,35 @@
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
15_ampere_sparse_tensorop_gemm
|
||||
ampere_sparse_tensorop_gemm.cu
|
||||
)
|
||||
|
||||
@ -0,0 +1,317 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/**
Please check examples 07, 08 and 17 for the basics of dense tensor op GEMM kernels. The NVIDIA
Ampere architecture also supports structured-sparse tensor ops for tf32, fp16, int8 and int4.

Sparse GEMM kernels need to take an additional E matrix that stores the metadata. The metadata
format is different for each data type. CUTLASS templates can automatically infer it from the
input A and B types; see the code below.

Moreover, matrix E needs to be preprocessed (reordered) so that ldmatrix can load it into
registers efficiently.
*/

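// A minimal sketch (not part of the original example) of the bookkeeping described above for
// 2:4 structured sparsity: operand A is stored compressed along K, and each element of the
// metadata tensor E covers several compressed A elements (the exact count depends on the data
// type and is exposed by the Gemm template as kElementsPerElementE).
constexpr int compressed_k_sketch(int k, int k_sparse /* e.g. 2 */) {
  return k / k_sparse;                   // columns of the compressed A operand
}
constexpr int metadata_k_sketch(int k, int k_sparse, int elements_per_e) {
  return k / k_sparse / elements_per_e;  // columns of the metadata tensor E
}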
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm_sparse.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
#include "cutlass/util/host_reorder.h"
|
||||
#include "cutlass/util/host_uncompress.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "helper.h"
|
||||
|
||||
// The code section below describes datatype for input, output matrices and computation between
|
||||
// elements in input matrices.
|
||||
using ElementAccumulator = int32_t; // <- data type of accumulator
|
||||
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
|
||||
using ElementInputA = cutlass::int4b_t; // <- data type of elements in input matrix A
|
||||
using ElementInputB = cutlass::int4b_t; // <- data type of elements in input matrix B
|
||||
using ElementOutput = int32_t; // <- data type of elements in output matrix D
|
||||
|
||||
// The code section below describes matrix layout of input and output matrices. Row Major for
|
||||
// Matrix A, Column Major for Matrix B and Row Major for Matrix C
|
||||
using LayoutInputA = cutlass::layout::RowMajor;
|
||||
using LayoutInputB = cutlass::layout::ColumnMajor;
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
|
||||
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
|
||||
using MMAOp = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
// This code section describes CUDA SM architecture number
|
||||
using SmArch = cutlass::arch::Sm80;
|
||||
|
||||
// This code section describes the tile size a thread block will compute
|
||||
using ShapeMMAThreadBlock =
|
||||
cutlass::gemm::GemmShape<128, 128, 256>; // <- threadblock tile M = 128, N = 128, K = 256
|
||||
// This code section describes tile size a warp will compute
|
||||
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 256>; // <- warp tile M = 64, N = 64, K = 256
|
||||
// This code section describes the size of MMA op
|
||||
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 128>; // <- MMA Op tile M = 16, N = 8, K = 128
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- threadblock swizzling function
|
||||
|
||||
// This code section describes the epilogue part of the kernel
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, // <- data type of output matrix
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value, // <- the number of elements per vectorized
|
||||
// memory access. For a byte, it's 16
|
||||
// elements. This becomes the vector width of
|
||||
// math instructions in the epilogue too
|
||||
ElementAccumulator, // <- data type of accumulator
|
||||
ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function
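// Note added for illustration (not part of the original example): with ElementOutput = int32_t the
// expression above evaluates to 128 / 32 = 4 elements per 128-bit vectorized access; the
// "16 elements" figure in the comment applies to 1-byte output elements.
static_assert(128 / cutlass::sizeof_bits<ElementOutput>::value == 4,
              "int32_t output yields 4 elements per 128-bit access");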
|
||||
|
||||
// Number of pipelines you want to use
|
||||
constexpr int NumStages = 3;
|
||||
|
||||
using Gemm = cutlass::gemm::device::SparseGemm<ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementAccumulator,
|
||||
MMAOp,
|
||||
SmArch,
|
||||
ShapeMMAThreadBlock,
|
||||
ShapeMMAWarp,
|
||||
ShapeMMAOp,
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages>;
|
||||
|
||||
// Data type and layout of meta data matrix E can be inferred from template Gemm.
|
||||
using ElementInputE = typename Gemm::ElementE;
|
||||
using LayoutInputE = cutlass::layout::RowMajor;
|
||||
using ReorderedLayoutInputE = typename Gemm::LayoutE;
|
||||
|
||||
// The properties below are defined in include/cutlass/arch/sp_mma_sm80.h
|
||||
// 50% Sparsity on Ampere
|
||||
constexpr int kSparse = Gemm::kSparse;
|
||||
// How many elements of A are covered per ElementE
|
||||
constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
|
||||
// The size of individual meta data
|
||||
constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
|
||||
|
||||
int run() {
|
||||
|
||||
const int length_m = 512;
|
||||
const int length_n = 512;
|
||||
const int length_k = 1024;
|
||||
|
||||
// Create a tuple of problem size for matrix multiplication
|
||||
cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
|
||||
|
||||
// Initialize tensors using CUTLASS helper functions
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
|
||||
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); // <- Create matrix A with dimensions M x (K / 2)
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a_uncompressed(
|
||||
problem_size.mk()); // <- Create uncompressed matrix A with dimensions M x K for reference computing
|
||||
|
||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
|
||||
problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
|
||||
problem_size.mn()); // <- Create matrix C with dimensions M x N
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
|
||||
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
|
||||
// CUTLASS kernel
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(
|
||||
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
|
||||
// reference kernel
|
||||
|
||||
// Create matrix E with dimensions M x (K / 2 / kElementsPerElementE). This one is used by reference computing.
|
||||
cutlass::HostTensor<ElementInputE, LayoutInputE> tensor_e(
|
||||
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
|
||||
// Same size as the above. The above one needs to be reordered and stored in this one.
|
||||
cutlass::HostTensor<ElementInputE, ReorderedLayoutInputE> tensor_e_reordered(
|
||||
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
|
||||
|
||||
// Fill input and output matrices on host using CUTLASS helper functions
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a.host_view(),
|
||||
1,
|
||||
ElementInputA(2),
|
||||
ElementInputA(-2),
|
||||
0); // <- Fill matrix A on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b.host_view(),
|
||||
1,
|
||||
ElementInputB(2),
|
||||
ElementInputB(-2),
|
||||
0); // <- Fill matrix B on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_c.host_view(),
|
||||
1,
|
||||
ElementOutput(2),
|
||||
ElementOutput(-2),
|
||||
0); // <- Fill matrix C on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomSparseMeta(
|
||||
tensor_e.host_view(),
|
||||
1,
|
||||
kMetaSizeInBits); // <- Fill matrix E on host with uniform-distribution random meta data
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_d.host_view()); // <- fill matrix D on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros
|
||||
|
||||
// Reorder the meta data matrix so that we can use ldmatrix to load them to tensor core
|
||||
// instructions.
|
||||
cutlass::reorder_meta(tensor_e_reordered.host_ref(), tensor_e.host_ref(),
|
||||
{problem_size.m(), problem_size.n(),
|
||||
problem_size.k() / kSparse / kElementsPerElementE});
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a.sync_device();
|
||||
tensor_b.sync_device();
|
||||
tensor_c.sync_device();
|
||||
tensor_d.sync_device();
|
||||
tensor_e_reordered.sync_device();
|
||||
tensor_ref_d.sync_device();
|
||||
|
||||
// Initialize alpha and beta for dot product computation
|
||||
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
|
||||
ElementComputeEpilogue beta = ElementComputeEpilogue(0);
|
||||
|
||||
// Split K dimension into 1 partition
|
||||
int split_k_slices = 1;
|
||||
|
||||
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
|
||||
// instantiated CUTLASS kernel
|
||||
typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication
|
||||
tensor_a.device_ref(), // <- reference to matrix A on device
|
||||
tensor_b.device_ref(), // <- reference to matrix B on device
|
||||
tensor_c.device_ref(), // <- reference to matrix C on device
|
||||
tensor_d.device_ref(), // <- reference to matrix D on device
|
||||
tensor_e_reordered.device_ref(), // <- reference to matrix E on device
|
||||
{alpha, beta}, // <- tuple of alpha and beta
|
||||
split_k_slices}; // <- k-dimension split factor
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm_op;
|
||||
|
||||
// Check the problem size is supported or not
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
status = gemm_op.initialize(arguments, workspace.get());
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// Launch initialized CUTLASS kernel
|
||||
status = gemm_op();
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// Uncompress tensor_a based on the metadata in tensor_e. We need it for the reference computation.
|
||||
cutlass::uncompress(tensor_a_uncompressed.host_ref(), tensor_a.host_ref(),
|
||||
tensor_e.host_ref(), problem_size.m(), problem_size.k());
|
||||
|
||||
// Create instantiation for host reference gemm kernel
|
||||
cutlass::reference::host::Gemm<ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementComputeEpilogue,
|
||||
ElementComputeEpilogue,
|
||||
typename Gemm::Operator>
|
||||
gemm_host;
|
||||
|
||||
// Launch host reference gemm kernel
|
||||
gemm_host(problem_size,
|
||||
alpha,
|
||||
tensor_a_uncompressed.host_ref(),
|
||||
tensor_b.host_ref(),
|
||||
beta,
|
||||
tensor_c.host_ref(),
|
||||
tensor_ref_d.host_ref());
|
||||
|
||||
// Copy output data from CUTLASS host for comparison
|
||||
tensor_d.sync_host();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
tensor_d.host_view(),
|
||||
tensor_ref_d.host_view());
|
||||
|
||||
std::cout << (passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
return (passed ? 0 : -1);
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
bool notSupported = false;
|
||||
|
||||
// Ampere Sparse Tensor Core operations exposed with mma.sync and ldmatrix are first available
|
||||
// in CUDA 11.1.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11.1 Toolkit to run these examples.
|
||||
|
||||
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.1 Toolkit or later." << std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (props.major * 10 + props.minor < 80) {
|
||||
std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
|
||||
<< std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
// Returning zero so this test passes on older Toolkits. Its action is a no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
return run();
|
||||
}
|
||||
36
examples/16_ampere_tensorop_conv2dfprop/CMakeLists.txt
Normal file
@ -0,0 +1,36 @@
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
16_ampere_tensorop_conv2dfprop
|
||||
ampere_tensorop_conv2dfprop.cu
|
||||
)
|
||||
|
||||
@ -0,0 +1,767 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/**

This example shows how to run convolution kernels using the functions and data structures provided
by CUTLASS with tensor cores, which we run on an NVIDIA Ampere GPU.

Writing a single high-performance convolution kernel is hard but doable, whereas writing
high-performance kernels at scale that work for multiple problem sizes with good abstractions is
really hard. CUTLASS solves this problem by providing simplified abstractions to compose multiple
sections of an implicit GEMM kernel. When used properly, the kernels can easily approach peak GPU
performance.

CUTLASS divides a kernel into hierarchical, composable sections. At the thread, warp, and
threadblock level, each unit computes its own tile, with higher-level tile sizes composed from
lower-level ones. Multiple thread-tiles (the tile size each thread computes) form a warp-tile (the
tile size each warp computes), and multiple warp-tiles form a threadblock-tile (the tile size
computed by a threadblock).

In this example, we split variable initialization into
1. Setting up data properties: describes how tensors are laid out in memory and how the kernel can
view them (logical-to-physical mapping).
2. Setting up computation properties: describes how the tensors set up above are used to compute
the output of the convolution.

First, we set up the data types of the input tensor A, the weights tensor B, and the output tensor
C, along with alpha and beta, since the convolution computes C = alpha * Conv2dFprop(A, B) + beta * C.
In CUTLASS, the kernels first compute Conv2dFprop(A, B) and defer the rest of the computation to the
end of the kernel, because alpha * X + beta * C is a simple element-wise operation on X
(Conv2dFprop(A, B)) and C. We call this the epilogue of the kernel. Hence, we set the data types for
alpha and beta to ElementComputeEpilogue = float and use cutlass::half_t for the elements of input
tensors A and B. We convey this to the CUTLASS kernel by initializing the template variables
ElementAccumulator (float), ElementComputeEpilogue (float), ElementInputA (cutlass::half_t),
ElementInputB (cutlass::half_t), and ElementOutput (float). Communicating just the data types is not
enough. Because the data is laid out linearly in memory, we also have to convey the tensor layouts.
We do that by initializing the template variables LayoutInputA, LayoutInputB, and LayoutOutput to
the CUTLASS type TensorNHWC. Next, we set up the rules to compute alpha * X + beta * C, i.e. the
epilogue of the kernel. We initialize the template variable EpilogueOp, which takes the data type of
the output ElementOutput (float), the number of elements per vectorized memory access (8), the data
type of the accumulator (float), and the data type used to compute the linear combination
(alpha * X + beta * C).

Now that we have set up the properties of the data, we have to set up the properties of the
computation.

Second, we set the tile sizes computed by the threadblock, warp, and MMA op to 128x128x64, 64x64x64,
and 16x8x16 (MxNxK), respectively. When these are passed to instantiate the CUTLASS implicit GEMM
kernel, it internally deduces the number of threads needed per threadblock, the amount of shared
memory, how to store data in a bank-conflict-free manner, and the many other variables required to
compose, initialize, and launch a high-performance implicit GEMM kernel. This is the beauty of
CUTLASS: it relieves the developer from understanding and coding complicated hardware optimizations
that can easily go wrong.

CUTLASS also supports multiple MMA pipelines in a threadblock. An MMA pipeline constitutes the whole
process of loading input data from global memory to shared memory, loading data from shared memory
to registers, doing matrix multiplication, and storing to global memory. The flow below shows a
typical multistage MMA pipeline
(see include/cutlass/conv/threadblock/implicit_gemm_multistage.h):

tensor in global memory --cp_async--> tile in shared memory --smem loads--> registers
--mma--> registers --global stores--> output to global memory

NVIDIA Ampere uses `cp_async` to build a multistage software pipeline that better hides latency.

A few more template variables are initialized, such as the swizzle that decides which threadblock
tile of the output is computed by which launched threadblock, and the CUDA SM architecture of the
GPU you want to run on.

These are all put together to create a template variable that describes the CUTLASS implicit GEMM
kernel, using the cutlass::conv::device::ImplicitGemmConvolution template.

The next step is to initialize physical data, then instantiate, initialize, and run the CUTLASS
kernel. We use CUTLASS utilities to initialize, fill, and compare tensors, since they are simple and
do not get in the way of learning CUTLASS.

Once all the tensors are initialized and filled with data, we create the arguments tuple used to
launch the CUTLASS kernel. It takes the problem size (N = 1, H = 64, W = 64, C = 128), filter size
(K = 64, R = 3, S = 3, C = 128), padding, strides, dilation, tensors, alpha, beta, and, importantly,
the split k-dimension factor. We also query CUTLASS for any scratch-space memory required by the
kernel we instantiated. If it is needed, we allocate it and pass it along with the other arguments
to initialize the CUTLASS kernel; then the kernel is launched.

In this example, we later launch a reference convolution kernel (from the CUTLASS utilities) to
check that the output from the CUTLASS kernel matches the reference implicit GEMM kernel.
*/
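// A minimal sketch (not part of the original example): implicit GEMM maps a Conv2dFprop problem
// onto a GEMM with extents GEMM_M = N*P*Q, GEMM_N = K, and GEMM_K = C*R*S. The numbers below use
// the problem size quoted in the comment above (N=1, H=64, W=64, C=128, K=64, R=3, S=3, unit
// stride, padding 1, so P = H and Q = W); they are for illustration only.
namespace implicit_gemm_mapping_sketch {
constexpr int N = 1, H = 64, W = 64, C = 128;  // activation extents
constexpr int K = 64, R = 3, S = 3;            // filter extents
constexpr int P = H, Q = W;                    // output extents for 3x3 filter, stride 1, padding 1
constexpr int GemmM = N * P * Q;               // 4096 rows of the implicit GEMM
constexpr int GemmN = K;                       // 64 columns
constexpr int GemmK = C * R * S;               // 1152 accumulation depth
}  // namespace implicit_gemm_mapping_sketch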
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/convolution.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
// The code section below describes datatype for input, output tensors and computation between
|
||||
// elements
|
||||
using ElementAccumulator = float; // Data type of accumulator
|
||||
using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta)
|
||||
using ElementInputA = cutlass::half_t; // Data type of elements in input tensor
|
||||
using ElementInputB = cutlass::half_t; // Data type of elements in input tensor
|
||||
using ElementOutput = float; // Data type of elements in output tensor
|
||||
|
||||
using LayoutInputA = cutlass::layout::TensorNHWC;
|
||||
using LayoutInputB = cutlass::layout::TensorNHWC;
|
||||
using LayoutOutput = cutlass::layout::TensorNHWC;
|
||||
|
||||
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
|
||||
using MMAOp = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
// This code section describes CUDA SM architecture number
|
||||
using SmArch = cutlass::arch::Sm80;
|
||||
|
||||
// This code section describes the tile size a thread block will compute
|
||||
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; // Threadblock tile shape
|
||||
|
||||
// This code section describes tile size a warp will compute
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; // Warp tile shape
|
||||
|
||||
// This code section describes the size of MMA op
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
// Number of pipelines you want to use
|
||||
constexpr int NumStages = 3;
|
||||
|
||||
// This code section describes whether the selected iterator algorithm is Analytic or Optimized
|
||||
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized;
|
||||
|
||||
// This code section describes the epilogue part of the kernel, we use default value
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, // Data type of output matrix.
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value, // The number of elements per vectorized
|
||||
// memory access. This becomes the vector width of
|
||||
// math instructions in the epilogue too.
|
||||
ElementAccumulator, // Data type of accumulator
|
||||
ElementComputeEpilogue>; // Data type for alpha/beta in linear combination
|
||||
|
||||
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementInputA, LayoutInputA,
|
||||
ElementInputB, LayoutInputB,
|
||||
ElementOutput, LayoutOutput,
|
||||
ElementAccumulator,
|
||||
MMAOp,
|
||||
SmArch,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
IteratorAlgorithm
|
||||
>::Kernel;
|
||||
|
||||
using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Command line options parsing
|
||||
struct Options {
|
||||
|
||||
bool help;
|
||||
cutlass::Tensor4DCoord input_size;
|
||||
cutlass::Tensor4DCoord filter_size;
|
||||
cutlass::Tensor4DCoord padding;
|
||||
cutlass::MatrixCoord conv_stride;
|
||||
cutlass::MatrixCoord dilation;
|
||||
bool reference_check;
|
||||
bool measure_performance;
|
||||
int iterations;
|
||||
bool save_workspace;
|
||||
ElementComputeEpilogue alpha;
|
||||
ElementComputeEpilogue beta;
|
||||
bool benchmark;
|
||||
std::string tag;
|
||||
|
||||
Options():
|
||||
help(false),
|
||||
input_size(1, 32, 32, 32),
|
||||
filter_size(32, 3, 3, 32),
|
||||
padding(1, 1, 1, 1),
|
||||
conv_stride(1, 1),
|
||||
dilation(1, 1),
|
||||
reference_check(false),
|
||||
measure_performance(true),
|
||||
iterations(20),
|
||||
save_workspace(false),
|
||||
alpha(1),
|
||||
beta(0),
|
||||
benchmark(false) { }
|
||||
|
||||
// Verify the problem size is compatible with the CUTLASS Convolution implementation.
|
||||
bool valid() {
|
||||
|
||||
//
|
||||
// CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently,
|
||||
// all pointers, strides, and tensor extents must be divisible by 8 elements.
|
||||
//
|
||||
int const kAlignment = 8;
|
||||
|
||||
if ((input_size.c() % kAlignment) ||
|
||||
(filter_size.n() % kAlignment)) {
|
||||
|
||||
// misaligned tensors
|
||||
return false;
|
||||
}
|
||||
|
||||
// Invalid padding
|
||||
if ((padding.h() != filter_size.h() / 2) ||
|
||||
(padding.w() != filter_size.w() / 2)) {
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Updates input and filter sizes
|
||||
void update(
|
||||
cutlass::Tensor4DCoord input_size,
|
||||
cutlass::Tensor4DCoord filter_size) {
|
||||
|
||||
this->input_size = input_size;
|
||||
this->filter_size = filter_size;
|
||||
|
||||
padding.n() = filter_size.h() / 2;
|
||||
padding.h() = filter_size.h() / 2;
|
||||
padding.w() = filter_size.w() / 2;
|
||||
padding.c() = filter_size.w() / 2;
|
||||
}
|
||||
|
||||
// Parses the command line
|
||||
void parse(int argc, char const **args) {
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("ref-check")) {
|
||||
reference_check = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("perf-check")) {
|
||||
measure_performance = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("save-workspace")) {
|
||||
save_workspace = true;
|
||||
}
|
||||
|
||||
if (cmd.check_cmd_line_flag("benchmark")) {
|
||||
benchmark = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("n", input_size.n());
|
||||
cmd.get_cmd_line_argument("h", input_size.h());
|
||||
cmd.get_cmd_line_argument("w", input_size.w());
|
||||
cmd.get_cmd_line_argument("c", input_size.c());
|
||||
|
||||
cmd.get_cmd_line_argument("k", filter_size.n());
|
||||
cmd.get_cmd_line_argument("r", filter_size.h());
|
||||
cmd.get_cmd_line_argument("s", filter_size.w());
|
||||
filter_size.c() = input_size.c();
|
||||
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("tag", tag);
|
||||
|
||||
if (filter_size.h() == 3 && filter_size.w() == 3) {
|
||||
padding = {1, 1, 1, 1};
|
||||
}
|
||||
else {
|
||||
filter_size.h() = 1;
|
||||
filter_size.w() = 1;
|
||||
padding = {0, 0, 0, 0};
|
||||
}
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "16_ampere_tensorop_conv2dfprop example\n\n"
|
||||
<< " This example uses Ampere's Tensor Core operators on F16 data types to compute\n"
|
||||
<< " forward convolution on tensors of layout NHWC.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --n=<int> Input tensor extent N\n"
|
||||
<< " --h=<int> Input tensor extent H\n"
|
||||
<< " --w=<int> Input tensor extent W\n"
|
||||
<< " --c=<int> Input tensor extent C\n"
|
||||
<< " --k=<int> Filter extent K\n"
|
||||
<< " --r=<int> Filter extent R\n"
|
||||
<< " --s=<int> Filter extent S\n\n"
|
||||
<< " --alpha=<float> Epilogue scalar alpha\n"
|
||||
<< " --beta=<float> Epilogue scalar beta\n\n"
|
||||
<< " --ref-check If set (true), reference check on the host is computed\n"
|
||||
<< " --perf-check If set (true), performance is measured.\n"
|
||||
<< " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n"
|
||||
<< " --save-workspace If set, workspace is written to a text file.\n"
|
||||
<< " --tag=<string> String to replicate across the first column in the results table\n";
|
||||
|
||||
out << "\n\nExamples:\n\n"
|
||||
<< "$ ./examples/16_ampere_tensorop_conv2dfprop/16_ampere_tensorop_conv2dfprop --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n"
|
||||
<< "$ ./examples/16_ampere_tensorop_conv2dfprop/16_ampere_tensorop_conv2dfprop --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Computes the output tensor size (NPQK)
|
||||
cutlass::Tensor4DCoord output_size() const {
|
||||
return cutlass::Tensor4DCoord(
|
||||
input_size.n(),
|
||||
(input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1,
|
||||
(input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1,
|
||||
filter_size.n());
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const {
|
||||
|
||||
// Number of multiply-adds = NPQK * CRS
|
||||
int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c());
|
||||
|
||||
// Two flops per multiply-add
|
||||
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Result {
|
||||
double runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cutlass::Status reference_check;
|
||||
cudaError_t error;
|
||||
|
||||
Result():
|
||||
runtime_ms(0),
|
||||
gflops(0),
|
||||
status(cutlass::Status::kSuccess),
|
||||
reference_check(cutlass::Status::kInvalid),
|
||||
error(cudaSuccess) { }
|
||||
|
||||
static std::ostream & print_header(std::ostream &out, Options const &options) {
|
||||
|
||||
if (!options.tag.empty()) {
|
||||
out << "Name,";
|
||||
}
|
||||
|
||||
out << "Layer,N,H,W,C,K,R,S,Runtime,GFLOPs";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
std::ostream & print(std::ostream &out, int idx, Options const &options) {
|
||||
|
||||
if (!options.tag.empty()) {
|
||||
out << options.tag << ",";
|
||||
}
|
||||
|
||||
out
|
||||
<< "conv_" << idx << ","
|
||||
<< options.input_size.n() << ","
|
||||
<< options.input_size.h() << ","
|
||||
<< options.input_size.w() << ","
|
||||
<< options.input_size.c() << ","
|
||||
<< options.filter_size.n() << ","
|
||||
<< options.filter_size.h() << ","
|
||||
<< options.filter_size.w() << ","
|
||||
<< runtime_ms << ","
|
||||
<< gflops;
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Runs one benchmark
|
||||
Result profile_convolution(Options const &options) {
|
||||
|
||||
Result result;
|
||||
|
||||
//
|
||||
// Allocate host-device tensors using the CUTLASS Utilities.
|
||||
//
|
||||
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.input_size);
|
||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.filter_size);
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.output_size());
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(options.output_size());
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(options.output_size());
|
||||
|
||||
//
|
||||
// Initialize tensors
|
||||
//
|
||||
|
||||
// Fill tensor A on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a.host_view(),
|
||||
1,
|
||||
ElementInputA(7),
|
||||
ElementInputA(-8),
|
||||
0);
|
||||
|
||||
// Fill tensor B on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b.host_view(),
|
||||
1,
|
||||
ElementInputB(7),
|
||||
ElementInputB(-8),
|
||||
0);
|
||||
|
||||
// Fill tensor C on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_c.host_view());
|
||||
|
||||
// Fill tensor D on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_d.host_view());
|
||||
|
||||
// Fill tensor D for reference on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_ref_d.host_view());
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a.sync_device();
|
||||
tensor_b.sync_device();
|
||||
tensor_c.sync_device();
|
||||
tensor_d.sync_device();
|
||||
tensor_ref_d.sync_device();
|
||||
|
||||
//
|
||||
// Define arguments for CUTLASS Convolution
|
||||
//
|
||||
|
||||
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
|
||||
|
||||
// Split K dimension into 1 partition
|
||||
int split_k_slices = 1;
|
||||
|
||||
// Construct Conv2dProblemSize with user defined output size
|
||||
cutlass::conv::Conv2dProblemSize problem_size(
|
||||
options.input_size,
|
||||
options.filter_size,
|
||||
options.padding,
|
||||
options.conv_stride,
|
||||
options.dilation,
|
||||
options.output_size(),
|
||||
mode,
|
||||
split_k_slices
|
||||
);
|
||||
|
||||
// Construct ImplicitGemm::Argument structure with conv2d
|
||||
// problem size, data pointers, and epilogue values
|
||||
typename ImplicitGemm::Arguments arguments{
|
||||
problem_size,
|
||||
tensor_a.device_ref(),
|
||||
tensor_b.device_ref(),
|
||||
tensor_c.device_ref(),
|
||||
tensor_d.device_ref(),
|
||||
{options.alpha, options.beta},
|
||||
};
|
||||
|
||||
//
|
||||
// Initialize CUTLASS Convolution
|
||||
//
|
||||
|
||||
ImplicitGemm implicit_gemm_op;
|
||||
|
||||
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
result.status = implicit_gemm_op.can_implement(arguments);
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
result.status = implicit_gemm_op.initialize(arguments, workspace.get());
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
//
|
||||
// Launch initialized CUTLASS kernel
|
||||
//
|
||||
result.status = implicit_gemm_op();
|
||||
|
||||
CUTLASS_CHECK(result.status);
|
||||
|
||||
//
|
||||
// Optional reference check
|
||||
//
|
||||
|
||||
if (options.reference_check) {
|
||||
std::cout << "Verification on host...\n";
|
||||
|
||||
// Compute with reference implementation
|
||||
cutlass::reference::host::Conv2dFprop<
|
||||
ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementComputeEpilogue,
|
||||
ElementAccumulator,
|
||||
cutlass::NumericConverter<ElementOutput, ElementComputeEpilogue>
|
||||
>(
|
||||
problem_size,
|
||||
tensor_a.host_ref(),
|
||||
tensor_b.host_ref(),
|
||||
tensor_c.host_ref(),
|
||||
tensor_ref_d.host_ref(),
|
||||
options.alpha,
|
||||
options.beta
|
||||
);
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
tensor_d.sync_host();
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
tensor_d.host_view(),
|
||||
tensor_ref_d.host_view());
|
||||
|
||||
if (!passed) {
|
||||
result.reference_check = cutlass::Status::kErrorInternal;
|
||||
std::cout << "ERROR - results miscompared.\n";
|
||||
}
|
||||
else {
|
||||
result.reference_check = cutlass::Status::kSuccess;
|
||||
std::cout << "Passed.\n";
|
||||
}
|
||||
}
|
||||
else {
|
||||
result.reference_check = cutlass::Status::kInvalid;
|
||||
}
|
||||
|
||||
if (options.save_workspace) {
|
||||
|
||||
std::stringstream ss;
|
||||
|
||||
ss << "16_ampere_workspace_conv2dfprop_"
|
||||
<< options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
|
||||
<< "_"
|
||||
<< options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
|
||||
<< ".dat";
|
||||
|
||||
std::ofstream output_workspace(ss.str());
|
||||
|
||||
output_workspace
|
||||
<< "Input = \n" << tensor_a.host_view() << "\n\n"
|
||||
<< "Filters = \n" << tensor_b.host_view() << "\n\n";
|
||||
|
||||
if (options.reference_check) {
|
||||
output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n";
|
||||
}
|
||||
|
||||
output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl;
|
||||
|
||||
std::cout << "Results written to '" << ss.str() << "'." << std::endl;
|
||||
}
|
||||
|
||||
//
|
||||
// Performance measurement
|
||||
//
|
||||
|
||||
if (options.measure_performance) {
|
||||
|
||||
cudaEvent_t events[2];
|
||||
|
||||
for (auto & event : events) {
|
||||
result.error = cudaEventCreate(&event);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// Record an event at the start of a series of convolution operations.
|
||||
result.error = cudaEventRecord(events[0]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Launch a sequence of implicit GEMM operations on the device
|
||||
for (int iteration = 0; iteration < options.iterations; ++iteration) {
|
||||
result.status = implicit_gemm_op();
|
||||
CUTLASS_CHECK(result.status);
|
||||
}
|
||||
|
||||
// Record an event when the convolutions have been launched.
|
||||
result.error = cudaEventRecord(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Wait for work on the device to complete.
|
||||
result.error = cudaEventSynchronize(events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Measure elapsed runtime
|
||||
float runtime_ms = 0;
|
||||
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
||||
if (result.error != cudaSuccess) {
|
||||
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Print average runtime and GFLOPs.
|
||||
result.runtime_ms = double(runtime_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.runtime_ms / 1000.0);
|
||||
|
||||
// Cleanup
|
||||
for (auto event : events) {
|
||||
(void)cudaEventDestroy(event);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
bool notSupported = false;
|
||||
|
||||
// Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
|
||||
|
||||
if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
|
||||
std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
|
||||
<< std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
Options options;
|
||||
|
||||
options.parse(argc, args);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (options.benchmark) {
|
||||
// Benchmark several layers
|
||||
|
||||
int batch_sizes[] = {1, 32, 64, 128, 256, 512};
|
||||
|
||||
struct Benchmark {
|
||||
int h, w, c, k, r, s;
|
||||
} layers[] = {
|
||||
{56, 56, 64, 256, 1, 1},
|
||||
{56, 56, 64, 64, 1, 1},
|
||||
{56, 56, 64, 64, 3, 3},
|
||||
{56, 56, 256, 64, 1, 1},
|
||||
{56, 56, 256, 512, 1, 1},
|
||||
{56, 56, 256, 128, 1, 1},
|
||||
{28, 28, 128, 128, 3, 3},
|
||||
{28, 28, 128, 512, 1, 1},
|
||||
{28, 28, 512, 128, 1, 1},
|
||||
{28, 28, 512, 1024, 1, 1},
|
||||
{28, 28, 512, 256, 1, 1},
|
||||
{14, 14, 256, 256, 3, 3},
|
||||
{14, 14, 256, 1024, 1, 1},
|
||||
{14, 14, 1024, 256, 1, 1},
|
||||
{14, 14, 1024, 2048, 1, 1},
|
||||
{14, 14, 1024, 512, 1, 1},
|
||||
{7, 7, 512, 512, 3, 3},
|
||||
};
|
||||
|
||||
Result::print_header(std::cout, options) << std::endl;
|
||||
|
||||
int idx = 1;
|
||||
|
||||
for (auto const &layer : layers) {
|
||||
for (auto N : batch_sizes) {
|
||||
|
||||
options.update({N, layer.h, layer.w, layer.c}, {layer.k, layer.r, layer.s, layer.c});
|
||||
|
||||
Result result = profile_convolution(options);
|
||||
result.print(std::cout, idx, options) << std::endl;
|
||||
}
|
||||
|
||||
++idx;
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
// Execute one problem size
|
||||
if (!options.valid()) {
|
||||
std::cerr << "Invalid problem." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
Result result = profile_convolution(options);
|
||||
|
||||
Result::print_header(std::cout, options) << std::endl;
|
||||
result.print(std::cout, 1, options) << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
36
examples/17_fprop_per_channel_bias/CMakeLists.txt
Normal file
@ -0,0 +1,36 @@
|
||||
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
|
||||
cutlass_example_add_executable(
|
||||
17_fprop_per_channel_bias
|
||||
fprop_per_channel_bias.cu
|
||||
)
|
||||
|
||||
306
examples/17_fprop_per_channel_bias/fprop_per_channel_bias.cu
Normal file
@ -0,0 +1,306 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/**
This is the convolution version of 12_gemm_bias_relu. Similarly, we put the bias vector in
operand C, and the rest is the same as a normal convolution.
*/
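// A minimal sketch (not part of the original example) of why a zero stride turns the 1x1x1xK
// tensor C into a per-channel bias: an NHWC tensor reference computes its offset from the
// coordinate and the strides, so with the leading strides set to zero every (n, p, q) maps to
// element k. The helper below only illustrates that indexing arithmetic; the kernel does not use it.
inline int broadcast_bias_offset_sketch(int n, int p, int q, int k) {
  int stride_n = 0, stride_h = 0, stride_w = 0;  // zero strides project away the N, H, W dimensions
  return n * stride_n + p * stride_h + q * stride_w + k;  // always just 'k'
}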
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/host_reorder.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/device/convolution.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
// The code section below describes datatype for input, output tensors and computation between
|
||||
// elements
|
||||
using ElementAccumulator = float; // Data type of accumulator
|
||||
using ElementComputeEpilogue = ElementAccumulator; // Data type of epilogue computation
|
||||
using ElementInputA = cutlass::half_t; // Data type of elements in input tensor
|
||||
using ElementInputB = cutlass::half_t; // Data type of elements in input tensor
|
||||
using ElementOutput = float; // Data type of elements in output tensor
|
||||
|
||||
using LayoutInputA = cutlass::layout::TensorNHWC;
|
||||
using LayoutInputB = cutlass::layout::TensorNHWC;
|
||||
using LayoutOutput = cutlass::layout::TensorNHWC;
|
||||
|
||||
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
|
||||
using MMAOp = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
// This code section describes CUDA SM architecture number
|
||||
using SmArch = cutlass::arch::Sm80;
|
||||
|
||||
// This code section describes the tile size a thread block will compute
|
||||
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape
|
||||
|
||||
// This code section describes tile size a warp will compute
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape
|
||||
|
||||
// This code section describes the size of MMA op
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
||||
|
||||
// Number of pipelines you want to use
|
||||
constexpr int NumStages = 4;
|
||||
|
||||
// This code section describes whether the selected iterator algorithm is Analytic or Optimized
|
||||
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized;
|
||||
|
||||
// This code section describes the epilogue part of the kernel, we use default value
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput, // Data type of output matrix.
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value, // The number of elements per vectorized
|
||||
// memory access. This becomes the vector width of
|
||||
// math instructions in the epilogue too.
|
||||
ElementAccumulator, // Data type of accumulator
|
||||
ElementComputeEpilogue, // Data type for alpha in linear combination
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling>; // alpha X C + per channel bias
|
||||
|
||||
|
||||
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
ElementInputA, LayoutInputA,
|
||||
ElementInputB, LayoutInputB,
|
||||
ElementOutput, LayoutOutput,
|
||||
ElementAccumulator,
|
||||
MMAOp,
|
||||
SmArch,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
IteratorAlgorithm
|
||||
>::Kernel;
|
||||
|
||||
using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int run() {
|
||||
|
||||
// Construct Conv2dProblemSize with user defined output size
|
||||
cutlass::conv::Conv2dProblemSize problem_size(
|
||||
{1, 7, 7, 512}, // activation
|
||||
{512, 3, 3, 512}, // filter
|
||||
{1, 1, 1, 1}, // padding
|
||||
{1, 1}, // striding
|
||||
{1, 1}, // dilation
|
||||
cutlass::conv::Mode::kCrossCorrelation, // mode (convolution or cross-correlation)
|
||||
1 // split-k slices
|
||||
);
|
||||
|
||||
// Initialize tensors using CUTLASS helper functions
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(problem_size.activation_extent());
|
||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(problem_size.filter_extent());
|
||||
|
||||
// Create tensor C with dimensions 1x1x1xK, which is the bias vector
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c_bias({1, 1, 1, problem_size.K});
|
||||
|
||||
// Create tensor D used to store output from CUTLASS kernel
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(problem_size.output_extent());
|
||||
// Create tensor D used to store the output from the reference kernel
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(problem_size.output_extent());
|
||||
|
||||
// Fill input and output matrices on host using CUTLASS helper functions
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a.host_view(),
|
||||
1,
|
||||
ElementInputA(4),
|
||||
ElementInputA(-4),
|
||||
0); // <- Fill tensor A on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b.host_view(),
|
||||
1,
|
||||
ElementInputB(4),
|
||||
ElementInputB(-4),
|
||||
0); // <- Fill tensor B on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_c_bias.host_view(),
|
||||
1,
|
||||
ElementOutput(4),
|
||||
ElementOutput(-4),
|
||||
0); // <- Fill matrix C on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_d.host_view()); // <- fill matrix D on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a.sync_device();
|
||||
tensor_b.sync_device();
|
||||
tensor_c_bias.sync_device();
|
||||
tensor_d.sync_device();
|
||||
tensor_ref_d.sync_device();
|
||||
|
||||
// Initialize alpha for dot product computation
|
||||
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
|
||||
|
||||
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
|
||||
// instantiated CUTLASS kernel
|
||||
typename ImplicitGemm::Arguments arguments{
|
||||
problem_size,
|
||||
tensor_a.device_ref(), // <- reference to tensor A on device
|
||||
tensor_b.device_ref(), // <- reference to tensor B on device
|
||||
// Tensor C is treated as the bias vector. We can enable the conv kernel
// to project away the N, H, W dimensions by setting the stride to zero.
|
||||
{tensor_c_bias.device_data(), LayoutOutput::Stride(0)},
|
||||
tensor_d.device_ref(), // <- reference to tensor D on device
|
||||
{alpha} };
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
ImplicitGemm implicit_gemm_op;
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check the problem size is supported or not
|
||||
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
status = implicit_gemm_op.initialize(arguments, workspace.get());
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// Launch initialized CUTLASS kernel
|
||||
status = implicit_gemm_op();
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
//
|
||||
// Create instantiation for device reference conv kernel
|
||||
//
|
||||
|
||||
// Launch device reference to compute strictly the product A * B
|
||||
cutlass::reference::device::Conv2d<
|
||||
ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementComputeEpilogue,
|
||||
ElementAccumulator,
|
||||
cutlass::NumericConverter<ElementOutput, ElementComputeEpilogue>>
|
||||
(
|
||||
cutlass::conv::Operator::kFprop,
|
||||
problem_size,
|
||||
tensor_a.device_ref(),
|
||||
tensor_b.device_ref(),
|
||||
tensor_c_bias.device_ref(),
|
||||
tensor_ref_d.device_ref(),
|
||||
alpha, ElementComputeEpilogue(0)
|
||||
);
|
||||
|
||||
// Wait for kernels to finish
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Copy output data from CUTLASS and reference kernel to host for comparison
|
||||
tensor_d.sync_host();
|
||||
tensor_ref_d.sync_host();
|
||||
|
||||
// Compute bias + relu in host code
|
||||
for (int n = 0; n < problem_size.N; ++n) {
|
||||
for (int p = 0; p < problem_size.P; ++p) {
|
||||
for (int q = 0; q < problem_size.Q; ++q) {
|
||||
for (int k = 0; k < problem_size.K; ++k) {
|
||||
|
||||
tensor_ref_d.at({n, p, q, k}) =
|
||||
std::max(ElementOutput(0),
|
||||
ElementOutput(tensor_ref_d.at({n, p, q, k}) +
|
||||
tensor_c_bias.at({0, 0, 0, k})));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
std::cout << (cutlass::reference::host::TensorEquals(tensor_d.host_view(),
|
||||
tensor_ref_d.host_view())
|
||||
? "Passed"
|
||||
: "Failed")
|
||||
<< std::endl;
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
||||
bool notSupported = false;
|
||||
|
||||
// Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
cudaDeviceProp props;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
|
||||
|
||||
if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
|
||||
std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
|
||||
<< std::endl;
|
||||
notSupported = true;
|
||||
}
|
||||
|
||||
if (notSupported) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return run();
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
Some files were not shown because too many files have changed in this diff.