.github/ISSUE_TEMPLATE/bug_report.md (vendored, 2 changes)

@@ -20,4 +20,4 @@ A clear and concise description of what you expected to happen.
- Environment location: [Bare-metal, Docker, Cloud(specify cloud provider)]

**Additional context**
Add any other context about the problem here.
@@ -32,4 +32,4 @@ A clear and concise description of what documentation you believe it is needed a
A clear and concise description of what you want to happen.

**Steps taken to search for needed documentation**
List any steps you have taken:
.github/ISSUE_TEMPLATE/submit_question.md (vendored, 2 changes)

@@ -7,4 +7,4 @@ assignees: ''

---

**What is your question?**
.github/workflows/labeler.yml (vendored, 2 changes)

@@ -8,4 +8,4 @@ jobs:
    steps:
    - uses: actions/labeler@main
      with:
        repo-token: "${{ secrets.GITHUB_TOKEN }}"
@@ -32,4 +32,4 @@ jobs:
      env:
        GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        GITHUB_PROJECT_URL: https://github.com/NVIDIA/cutlass
        GITHUB_PROJECT_COLUMN_NAME: 'Needs prioritizing'
.github/workflows/stale.yml (vendored, 2 changes)

@@ -54,4 +54,4 @@ jobs:
          exempt-pr-labels: "0 - Blocked,0 - Backlog,good first issue"
          days-before-pr-stale: 90
          days-before-pr-close: -1
          operations-per-run: 50
CHANGELOG.md (14 changes)

@@ -1,5 +1,18 @@
# NVIDIA CUTLASS Changelog

## [2.10.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.10.0) (2022-08-23)
* [Grouped convolution targeting implicit GEMM](test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu)
* [Depthwise separable convolution](test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu)
* Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel
* [Grouped GEMM for Multihead Attention](examples/50_multi_head_attention)
* [GEMM + Layer norm fusion for Ampere](examples/37_gemm_layernorm_gemm_fusion/)
* Updates and bugfixes from the community (thanks!)

* **Deprecation announcement:** CUTLASS plans to deprecate the following:
  * Maxwell and Pascal GPU architectures
  * Ubuntu 16.04
  * CUDA 10.2

## [2.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.9.0) (2022-04-21)

* [First layer Convolution kernels](/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment

@@ -37,6 +50,7 @@
* Optimal performance using [**CUDA 11.7**](https://developer.nvidia.com/cuda-downloads)
* Updates and bugfixes from the community (thanks!)

## [2.8.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.8.0) (2021-11-19)

* **TF32x3:** emulated single-precision using Tensor Cores
CITATION.cff (82 changes)

@@ -1,82 +0,0 @@
cff-version: 1.2.0
title: CUTLASS
message: >-
  If you use this software, please cite using the
  following metadata.
type: software
authors:
  - given-names: Andrew
    email: akerr@nvidia.com
    family-names: Kerr
    affiliation: NVIDIA
  - given-names: Haicheng
    family-names: Wu
    affiliation: NVIDIA
    email: haichengw@nvidia.com
  - given-names: Manish
    family-names: Gupta
    affiliation: Google
    email: manigupta@google.com
  - given-names: Dustyn
    family-names: Blasig
    email: dblasig@nvidia.com
    affiliation: NVIDIA
  - given-names: Pradeep
    family-names: Ramani
    email: prramani@nvidia.com
    affiliation: NVIDIA
  - given-names: Duane
    family-names: Merrill
    email: dumerrill@nvidia.com
    affiliation: NVIDIA
  - given-names: Aniket
    family-names: Shivam
    email: ashivam@nvidia.com
    affiliation: NVIDIA
  - given-names: Piotr
    family-names: Majcher
    email: pmajcher@nvidia.com
    affiliation: NVIDIA
  - given-names: Paul
    family-names: Springer
    email: pspringer@nvidia.com
    affiliation: NVIDIA
  - given-names: Markus
    family-names: Hohnerbach
    affiliation: NVIDIA
    email: mhohnerbach@nvidia.com
  - given-names: Jin
    family-names: Wang
    email: jinw@nvidia.com
    affiliation: NVIDIA
  - given-names: Matt
    family-names: Nicely
    email: mnicely@nvidia.com
    affiliation: NVIDIA
repository-code: 'https://github.com/NVIDIA/cutlass'
abstract: >-
  CUTLASS is a collection of CUDA C++ template
  abstractions for implementing high-performance
  matrix-multiplication (GEMM) and related
  computations at all levels and scales within CUDA.
  It incorporates strategies for hierarchical
  decomposition and data movement similar to those
  used to implement cuBLAS and cuDNN. CUTLASS
  decomposes these "moving parts" into reusable,
  modular software components abstracted by C++
  template classes. These thread-wide, warp-wide,
  block-wide, and device-wide primitives can be
  specialized and tuned via custom tiling sizes, data
  types, and other algorithmic policies. The resulting
  flexibility simplifies their use as building blocks
  within custom kernels and applications.
keywords:
  - 'cutlass, tensor cores, cuda'
license: BSD-3-Clause
license-url: https://github.com/NVIDIA/cutlass/blob/v2.9.0/LICENSE.txt
version: '2.9'
date-released: '2022-04-27'
identifiers:
  - type: url
    value: "https://github.com/NVIDIA/cutlass/tree/v2.9.0"
    description: The GitHub release URL of tag 2.9.0
@@ -38,7 +38,7 @@ endif()

message(STATUS "CMake Version: ${CMAKE_VERSION}")

-project(CUTLASS VERSION 2.9.0 LANGUAGES CXX)
+project(CUTLASS VERSION 2.10.0 LANGUAGES CXX)
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)

if (CUDA_VERSION VERSION_LESS 10.2)
@@ -11,12 +11,19 @@ Andrew Kerr
Haicheng Wu
Manish Gupta
Dustyn Blasig
Pradeep Ramani
Cris Cecka
Vijay Thakkar
Aniket Shivam
Honghao Lu
Ethan Yan
Zhaodong Chen
Jack Kosaian
Yujia Zhai
Naila Farooqui
Piotr Majcher
Paul Springer
Jin Wang
Aniket Shivam
Chinmay Talegaonkar
Shang Zhang
Scott Yokim

@@ -53,7 +60,6 @@ Nick Zhao
## ACKNOWLEDGEMENTS

Girish Bharambe
Cris Cecka
Luke Durant
Olivier Giroux
Stephen Jones
README.md (42 changes)

@@ -1,8 +1,8 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")

-# CUTLASS 2.9
+# CUTLASS 2.10

-_CUTLASS 2.9 - April 2022_
+_CUTLASS 2.10 - August 2022_

CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) and related computations at all levels

@@ -18,7 +18,9 @@ To support a wide variety of applications, CUTLASS provides extensive support for
mixed-precision computations, providing specialized data-movement and
multiply-accumulate abstractions for half-precision floating
point (FP16), BFloat16 (BF16), Tensor Float 32 (TF32),
-single-precision floating point (FP32), double-precision floating
+single-precision floating point (FP32),
+[FP32 emulation via tensor core instruction](/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm),
+double-precision floating
point (FP64) types, integer data types (4b and 8b), and binary data types (1b).
CUTLASS demonstrates warp-synchronous matrix multiply operations
targeting the programmable, high-throughput _Tensor Cores_ implemented by
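As a concrete illustration of how these abstractions are specialized by data type, here is a minimal sketch against the CUTLASS 2.x device-level GEMM API (the layouts and target architecture below are illustrative choices, not requirements):

```cpp
#include "cutlass/gemm/device/gemm.h"

// FP16 inputs with FP32 accumulation on Ampere Tensor Cores -- one of the
// mixed-precision combinations described above.
using Gemm = cutlass::gemm::device::Gemm<
    cutlass::half_t, cutlass::layout::RowMajor,     // ElementA, LayoutA
    cutlass::half_t, cutlass::layout::ColumnMajor,  // ElementB, LayoutB
    float,           cutlass::layout::RowMajor,     // ElementC/D, LayoutC
    float,                                          // ElementAccumulator
    cutlass::arch::OpClassTensorOp,                 // use Tensor Cores
    cutlass::arch::Sm80>;                           // target architecture
```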
@@ -34,26 +36,14 @@ See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
See the [functionality listing](/media/docs/functionality.md) for the list of operations
supported at each level of the execution model hierarchy.

-# What's New in CUTLASS 2.9
+# What's New in CUTLASS 2.10

-CUTLASS 2.9 is an update to CUTLASS adding:
-- [First layer Convolution kernels](/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment
-- [BLAS3](https://docs.nvidia.com/cuda/cublas/index.html#cublas-level-3-function-reference) operators accelerated by Tensor Cores
-  - [SYRK](/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu), [HERK](/test/unit/gemm/device/herk_cf32h_cf32n_tensor_op_fast_f32_sm80.cu),
-  - [SYR2K](/test/unit/gemm/device/syr2k_f32n_f32n_tensor_op_fast_f32_sm80.cu), [HER2K](/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu),
-  - [Out-of-place TRMM](/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu), and
-  - [SYMM](/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu), [HEMM](/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_fast_f32_ls_sm80.cu)
-- [CUTLASS Python](/examples/40_cutlass_py) demonstrating JIT compilation of CUTLASS kernels and a Python-based runtime using [CUDA Python](https://developer.nvidia.com/cuda-python)
-- [GEMM + Softmax example](/examples/35_gemm_softmax)
-- [Gather and Scatter Fusion with GEMM](/examples/36_gather_scatter_fusion) can gather inputs and scatter outputs based on index vectors in the same GEMM kernel
-- [Back-to-back GEMM/CONV](examples/13_two_tensor_op_fusion) fully supports buffering the first GEMM/CONV results in shared memory for the latter one to use. Bias vector add is also supported in the first GEMM/CONV.
-- [Transposed Convolution](/examples/34_transposed_conv2d) (a.k.a. deconvolution) support, which reuses the Dgrad implementation
-- [Utility functions](/tools/util/include/cutlass/util) that can pad NHWC and convert between NCHW and NHWC
-- [Small-alignment implicit GEMM](https://github.com/NVIDIA/cutlass/issues/242) support for Fprop/Dgrad/Wgrad so that padding is no longer mandated to use tensor cores
-- Epilogue enhancements with performance improvements, more activation functions, and more fusion patterns
-- [Group GEMM](/examples/24_gemm_grouped) thread block number calculation fix
-- Optimal performance using [CUDA 11.7](https://developer.nvidia.com/cuda-downloads)
-- [Parallel GEMM split-k](https://github.com/NVIDIA/cutlass/pull/277) support in the CUTLASS profiler
+CUTLASS 2.10 is an update to CUTLASS adding:
+- [Grouped convolution targeting implicit GEMM](test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu)
+- [Depthwise separable convolution](test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu)
+- Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel
+- [Grouped GEMM for Multihead Attention](examples/50_multi_head_attention)
+- [GEMM + Layer norm fusion for Ampere](examples/37_gemm_layernorm_gemm_fusion/)
+- Updates and bugfixes from the community (thanks!)
+- **Deprecation announcement:** CUTLASS plans to deprecate the following:
+  - Maxwell and Pascal GPU architectures
@@ -249,15 +239,15 @@ examples/
  12_gemm_bias_relu/                 # example demonstrating GEMM fused with bias and relu

-  13_fused_two_gemms/                # example demonstrating two GEMms fused in one kernel
+  13_fused_two_gemms/                # example demonstrating two GEMMs fused in one kernel

  22_ampere_tensorop_conv2dfprop/    # example demonstrating integer implicit GEMM convolution (forward propagation) using Ampere Tensor Cores

-  31_basic_syrk                      # example demonstrating Symetric rank-K update
+  31_basic_syrk                      # example demonstrating Symmetric Rank-K update

-  32_basic_trmm                      #
+  32_basic_trmm                      # example demonstrating Triangular Matrix-Matrix multiplication

-  33_ampere_3xtf32_tensorop_symm     #
+  33_ampere_3xtf32_tensorop_symm     # example demonstrating Symmetric Matrix-Matrix multiplication with FP32 emulation

  35_gemm_softmax                    # example demonstrating GEMM fused with Softmax in mixed precision using Ampere Tensor Cores
@@ -54,12 +54,11 @@ using ElementInputA = cutlass::half_t; // <- data type of elements in input matrix A
using ElementInputB = cutlass::half_t;              // <- data type of elements in input matrix B
using ElementOutput = float;                        // <- data type of elements in output matrix D

// The code section below describes matrix layout of input and output matrices.
// Column Major for Matrix A, B and C.

-// Note that if the output is column major, the bias has to be per row, i.e. every row has a different bias.
-// If the output is row major, the bias has to be per column, i.e. every column has a different bias.
-// Below are some other notes:
+//
+// Note this example only works for ColumnMajor output because
+//  1) we only have a row major epilogue.
+//  2) we swap A and B if the output is column major, so we can still use the
+//     row major epilogue.
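The swap relies on the transpose identity (A·B)ᵀ = Bᵀ·Aᵀ: writing Dᵀ with a row-major epilogue produces exactly the column-major D the caller asked for. A small self-contained host check of that identity (an editorial illustration, not code from the example):

```cpp
#include <cassert>

// Verifies (A * B)^T == B^T * A^T on a tiny problem; sizes and values are arbitrary.
int main() {
  const int M = 2, N = 3, K = 4;
  float A[M][K], B[K][N], D[M][N] = {}, Dt[N][M] = {};
  for (int i = 0; i < M; ++i) for (int k = 0; k < K; ++k) A[i][k] = float(i + k);
  for (int k = 0; k < K; ++k) for (int j = 0; j < N; ++j) B[k][j] = float(k - j);
  for (int i = 0; i < M; ++i) for (int j = 0; j < N; ++j)
    for (int k = 0; k < K; ++k) D[i][j] += A[i][k] * B[k][j];   // D = A * B
  for (int j = 0; j < N; ++j) for (int i = 0; i < M; ++i)
    for (int k = 0; k < K; ++k) Dt[j][i] += B[k][j] * A[i][k];  // Dt = B^T * A^T
  for (int i = 0; i < M; ++i) for (int j = 0; j < N; ++j)
    assert(D[i][j] == Dt[j][i]);  // identical summation order => exact equality
  return 0;
}
```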
@@ -457,9 +457,13 @@ Result profile_convolution(Options const &options) {
      ElementInputB(-8),
      0);

-  // Fill tensor C on host with zeros
-  cutlass::reference::host::TensorFill(
-      tensor_c.host_view());
+  // Fill tensor C on host with uniform-distribution random data
+  cutlass::reference::host::TensorFillRandomUniform(
+      tensor_c.host_view(),
+      1,
+      ElementOutput(7),
+      ElementOutput(-8),
+      0);

  // Fill tensor D on host with zeros
  cutlass::reference::host::TensorFill(
@@ -686,7 +690,7 @@ int main(int argc, char const **args) {
  cudaDeviceProp props;
  CUDA_CHECK(cudaGetDeviceProperties(&props, 0));

-  if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
+  if (!(props.major >= 8)) {
    std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
              << std::endl;
    notSupported = true;

@@ -290,7 +290,7 @@ int main(int argc, char const **args) {
  cudaDeviceProp props;
  CUDA_CHECK(cudaGetDeviceProperties(&props, 0));

-  if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
+  if (!(props.major >= 8)) {
    std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
              << std::endl;
    notSupported = true;

@@ -326,7 +326,7 @@ int main(int argc, char const **args) {
  cudaDeviceProp props;
  CUDA_CHECK(cudaGetDeviceProperties(&props, 0));

-  if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
+  if (!(props.major >= 8)) {
    std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
              << std::endl;
    notSupported = true;
@@ -32,7 +32,7 @@
/**
The example demonstrates how to reduce one of the operands of the GEMM along the k-dimension when
computing GEMM. So the output also contains either an Mx1 or 1xN vector. It only works with Ampere
-HMMA 16x8x16 FP16 tensor cores, though it is not difficult to apply to other Turing/Ampere tensor
+16x8x16 FP16/BF16 tensor cores, though it is not difficult to apply to other Turing/Ampere tensor
core instructions.

Most of the reduction is done at the gemm/warp level; see gemm/warp/mma_with_reduction_tensor_op.h
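Concretely, alongside D = A·B the kernel emits the sums of one operand over the k-dimension; a hypothetical host-side reference for the A-operand case (names here are illustrative, not from the example):

```cpp
#include <vector>

// v[i] = sum over k of A[i][k] -- the M x 1 reduction vector produced alongside the GEMM.
std::vector<float> reduce_operand_a(std::vector<float> const &A, int M, int K) {
  std::vector<float> v(M, 0.0f);
  for (int i = 0; i < M; ++i)
    for (int k = 0; k < K; ++k)
      v[i] += A[i * K + k];  // A stored row-major
  return v;
}
```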
@@ -67,9 +67,9 @@ epilogue/threadblock/epilogue_gemm_k_reduction.h
// elements
using ElementAccumulator = float;                   // Data type of accumulator
using ElementComputeEpilogue = ElementAccumulator;  // Data type of epilogue computation
-using ElementInputA = cutlass::half_t;              // Data type of elements in input tensor
-using ElementInputB = cutlass::half_t;              // Data type of elements in input tensor
-using ElementOutput = cutlass::half_t;              // Data type of elements in output tensor
+using ElementInputA = cutlass::bfloat16_t;          // Data type of elements in input tensor
+using ElementInputB = cutlass::bfloat16_t;          // Data type of elements in input tensor
+using ElementOutput = cutlass::bfloat16_t;          // Data type of elements in output tensor

using LayoutInputA = cutlass::layout::ColumnMajor;
using LayoutInputB = cutlass::layout::RowMajor;
@@ -369,22 +369,22 @@ Result profile(Options const &options) {
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_a.host_view(),
      1,
-      ElementInputA(4),
-      ElementInputA(-4),
+      ElementInputA(2),
+      ElementInputA(-2),
      0);  // <- Fill tensor A on host with uniform-distribution random data

  cutlass::reference::host::TensorFillRandomUniform(
      tensor_b.host_view(),
      1,
-      ElementInputB(4),
-      ElementInputB(-4),
+      ElementInputB(2),
+      ElementInputB(-2),
      0);  // <- Fill tensor B on host with uniform-distribution random data

  cutlass::reference::host::TensorFillRandomUniform(
      tensor_c.host_view(),
      1,
-      ElementOutput(4),
-      ElementOutput(-4),
+      ElementOutput(2),
+      ElementOutput(-2),
      0);  // <- Fill matrix C on host with uniform-distribution random data
  cutlass::reference::host::TensorFill(
      tensor_d.host_view());  // <- fill matrix D on host with zeros
@@ -612,10 +612,10 @@ Result profile(Options const &options) {

  if (options.reference_check) {
    output_workspace << "Reference D = \n" << tensor_ref_d.host_view() << "\n\n";
-    output_workspace << "Reference reduction vector= \n" << tensor_ref_reduction.host_view() << "\n\n";
+    output_workspace << "Reference reduction vector = \n" << tensor_ref_reduction.host_view() << "\n\n";
  }

-  output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl;
+  output_workspace << "Computed D = \n" << tensor_d.host_view() << std::endl;
  output_workspace << "Computed reduction vector = \n" << tensor_reduction.host_view() << std::endl;

  std::cout << "Results written to '" << ss.str() << "'." << std::endl;
@@ -699,7 +699,7 @@ int main(int argc, char const **args) {
  cudaDeviceProp props;
  CUDA_CHECK(cudaGetDeviceProperties(&props, 0));

-  if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
+  if (!(props.major >= 8)) {
    std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
              << std::endl;
    notSupported = true;

(File diff suppressed because it is too large.)
@@ -34,3 +34,8 @@ cutlass_example_add_executable(
  ampere_fprop_mainloop_fusion.cu
)

+cutlass_example_add_executable(
+  25_ampere_3d_fprop_mainloop_fusion
+  ampere_3d_fprop_mainloop_fusion.cu
+)
@@ -0,0 +1,776 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/**

This example shows how to fuse per-channel scale+bias+relu of the activations
into the 3D fprop mainloop.

Compared with the original 3D fprop kernel, this example has two more vectors, one for
the scale and one for the bias. The length of the vectors is the same as the
activation channel count. This kernel loads the vectors when the associated
activation channels are loaded in the mainloop. Between reading the
activations and scale/bias data from the shared memory and calling tensor core
instructions, scale+bias+relu is computed in the register file.

This example is customized for the Ampere 16816 FP16 tensor core instruction.
Changing to different data types or a different tensor core instruction requires
source code changes. See
include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h for more
technical details.

This example is adapted from 25_ampere_fprop_mainloop_fusion. The command
line is the same.
*/
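// Illustration (editorial sketch, not part of the original source): in scalar
// form, the fused mainloop computes, for an activation element x in channel c,
//
//     y = max(0, x * scale[c] + bias[c])
//
// and feeds y, rather than x, into the 16816 tensor core instruction.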
#include <iostream>
#include <fstream>
#include <sstream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/conv/kernel/default_conv3d_fprop_fusion.h"
#include "cutlass/conv/device/implicit_gemm_convolution_fusion.h"

#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/convolution.h"

#include "helper.h"
// The code section below describes the datatypes for input and output tensors and computation between
// elements
using ElementAccumulator = float;                   // Data type of accumulator
using ElementComputeEpilogue = float;               // Data type of epilogue computation (alpha, beta)
using ElementInputA = cutlass::half_t;              // Data type of elements in input tensor
using ElementInputB = cutlass::half_t;              // Data type of elements in input tensor
using ElementInputScaleBias = cutlass::half_t;      // Data type of elements in input scale and bias vectors
using ElementOutput = float;                        // Data type of elements in output tensor

using LayoutInputA = cutlass::layout::TensorNDHWC;
using LayoutInputB = cutlass::layout::TensorNDHWC;
using LayoutInputScaleBias = cutlass::layout::RowMajor;
using LayoutOutput = cutlass::layout::TensorNDHWC;

// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = cutlass::arch::OpClassTensorOp;

// This code section describes the CUDA SM architecture number
using SmArch = cutlass::arch::Sm80;

// This code section describes the tile size a thread block will compute
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;  // Threadblock tile shape

// This code section describes the tile size a warp will compute
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;           // Warp tile shape

// This code section describes the size of the MMA op
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;     // TensorCore instruction shape

// This code section describes how threadblocks are scheduled on the GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

// Number of pipeline stages to use
constexpr int NumStages = 4;
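// (Each stage is one global-to-shared-memory buffer; with 4 stages, the
// asynchronous copies can run a few iterations ahead of the tensor core math.)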
// This code section describes whether the selected iterator algorithm is Analytic or Optimized
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized;

// This code section describes the epilogue part of the kernel; we use the default value
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,                                     // Data type of output matrix.
    128 / cutlass::sizeof_bits<ElementOutput>::value,  // The number of elements per vectorized
                                                       // memory access. This becomes the vector width of
                                                       // math instructions in the epilogue too.
    ElementAccumulator,                                // Data type of accumulator
    ElementComputeEpilogue>;                           // Data type for alpha/beta in linear combination
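// With ElementOutput = float, the vector width above works out to
// 128 / 32 = 4 elements per vectorized memory access.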
using Conv3dFpropFusionKernel = typename cutlass::conv::kernel::DefaultConv3dFpropFusion<
    ElementInputA, LayoutInputA,
    ElementInputB, LayoutInputB,
    ElementInputScaleBias, LayoutInputScaleBias,
    ElementOutput, LayoutOutput,
    ElementAccumulator,
    MMAOp,
    SmArch,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOp,
    SwizzleThreadBlock,
    NumStages,
    cutlass::arch::OpMultiplyAdd,
    IteratorAlgorithm
>::Kernel;

using ImplicitGemmFusion = cutlass::conv::device::ImplicitGemmConvolutionFusion<Conv3dFpropFusionKernel>;

/////////////////////////////////////////////////////////////////////////////////////////////////

// Command line options parsing
struct Options {

  bool help;
  cutlass::Tensor5DCoord input_size;
  cutlass::Tensor5DCoord filter_size;
  cutlass::Coord<3> padding;
  cutlass::Coord<3> conv_stride;
  cutlass::Coord<3> dilation;
  bool reference_check;
  bool measure_performance;
  int iterations;
  bool save_workspace;
  ElementComputeEpilogue alpha;
  ElementComputeEpilogue beta;
  bool benchmark;
  std::string tag;

  Options():
    help(false),
    input_size(1, 32, 32, 32, 32),
    filter_size(32, 3, 3, 3, 32),
    padding(cutlass::make_Coord(1, 1, 1)),
    conv_stride(cutlass::make_Coord(1, 1, 1)),
    dilation(cutlass::make_Coord(1, 1, 1)),
    reference_check(true),
    measure_performance(false),
    iterations(20),
    save_workspace(false),
    alpha(1),
    beta(0),
    benchmark(false) { }

  // Verify the problem size is compatible with the CUTLASS Convolution implementation.
  bool valid() {

    //
    // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently,
    // all pointers, strides, and tensor extents must be divisible by 8 elements.
    //
    int const kAlignment = 8;

    if ((input_size.c() % kAlignment) ||
        (filter_size.n() % kAlignment)) {

      // misaligned tensors
      return false;
    }

    // Invalid padding
    if ((padding[0] != filter_size.d() / 2) ||
        (padding[1] != filter_size.h() / 2) ||
        (padding[2] != filter_size.w() / 2)) {

      return false;
    }

    return true;
  }

  /// Updates input and filter sizes
  void update(
      cutlass::Tensor5DCoord input_size,
      cutlass::Tensor5DCoord filter_size,
      cutlass::Coord<3> stride) {

    this->input_size = input_size;
    this->filter_size = filter_size;
    conv_stride = stride;

    padding[0] = filter_size.d() / 2;
    padding[1] = filter_size.h() / 2;
    padding[2] = filter_size.w() / 2;
  }

  // Parses the command line
  void parse(int argc, char const **args) {
    cutlass::CommandLine cmd(argc, args);

    if (cmd.check_cmd_line_flag("help")) {
      help = true;
    }

    if (cmd.check_cmd_line_flag("ref-check")) {
      reference_check = true;
    }

    if (cmd.check_cmd_line_flag("perf-check")) {
      measure_performance = true;
    }

    if (cmd.check_cmd_line_flag("save-workspace")) {
      save_workspace = true;
    }

    if (cmd.check_cmd_line_flag("benchmark")) {
      benchmark = true;
    }

    cmd.get_cmd_line_argument("n", input_size.n());
    cmd.get_cmd_line_argument("d", input_size.d());
    cmd.get_cmd_line_argument("h", input_size.h());
    cmd.get_cmd_line_argument("w", input_size.w());
    cmd.get_cmd_line_argument("c", input_size.c());

    cmd.get_cmd_line_argument("k", filter_size.n());
    cmd.get_cmd_line_argument("t", filter_size.d());
    cmd.get_cmd_line_argument("r", filter_size.h());
    cmd.get_cmd_line_argument("s", filter_size.w());
    filter_size.c() = input_size.c();

    cmd.get_cmd_line_argument("alpha", alpha);
    cmd.get_cmd_line_argument("beta", beta);

    cmd.get_cmd_line_argument("iterations", iterations);
    cmd.get_cmd_line_argument("tag", tag);

    if (filter_size.d() == 3 && filter_size.h() == 3 && filter_size.w() == 3) {
      padding = cutlass::make_Coord(1, 1, 1);
    }
    else {
      filter_size.d() = 1;
      filter_size.h() = 1;
      filter_size.w() = 1;
      padding = cutlass::make_Coord(0, 0, 0);
    }
  }

  /// Prints the usage statement.
  std::ostream & print_usage(std::ostream &out) const {

    out << "25_ampere_3d_fprop_mainloop_fusion example\n\n"
      << "  This example fuses scale+bias+relu of the activations into Ampere's\n"
      << "  Tensor Core operators on F16 data types to compute\n"
      << "  forward convolution on tensors of layout NDHWC.\n\n"
      << "Options:\n\n"
      << "  --help               If specified, displays this usage statement.\n\n"
      << "  --n <int>            Input tensor extent N\n"
      << "  --d <int>            Input tensor extent D\n"
      << "  --h <int>            Input tensor extent H\n"
      << "  --w <int>            Input tensor extent W\n"
      << "  --c <int>            Input tensor extent C\n"
      << "  --k <int>            Filter extent K\n"
      << "  --t <int>            Filter extent T\n"
      << "  --r <int>            Filter extent R\n"
      << "  --s <int>            Filter extent S\n\n"
      << "  --alpha <float>      Epilogue scalar alpha\n"
      << "  --beta <float>       Epilogue scalar beta\n\n"
      << "  --ref-check          If set (true), reference check on the host is computed\n"
      << "  --perf-check         If set (true), performance is measured.\n"
      << "  --benchmark          If set (true), performance benchmarking on several layers and batch-size.\n"
      << "  --iterations <int>   Number of profiling iterations to perform.\n"
      << "  --save-workspace     If set, workspace is written to a text file.\n"
      << "  --tag <string>       String to replicate across the first column in the results table\n";

    out << "\n\nExamples:\n\n"
      << "$ ./25_ampere_3d_fprop_mainloop_fusion --n=32 --d=96 --h=96 --w=96 --c=64 --k=64 --t=1 --r=1 --s=1\n\n"
      << "$ ./25_ampere_3d_fprop_mainloop_fusion --n=1 --d=224 --h=224 --w=224 --c=32 --k=32 --t=3 --r=3 --s=3 --ref-check\n\n"
      << "$ ./25_ampere_3d_fprop_mainloop_fusion --n=19 --d=94 --h=96 --w=96 --c=128 --k=128 --t=1 --r=1 --s=1\n\n";

    return out;
  }

  /// Computes the output tensor size (NPQK)
  cutlass::Tensor5DCoord output_size() const {
    return cutlass::Tensor5DCoord(
        input_size.n(),
        (input_size.d() + padding[0] + padding[0] - filter_size.d()) / conv_stride[0] + 1,
        (input_size.h() + padding[1] + padding[1] - filter_size.h()) / conv_stride[1] + 1,
        (input_size.w() + padding[2] + padding[2] - filter_size.w()) / conv_stride[2] + 1,
        filter_size.n());
  }
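  // For example, with the default sizes above (32x32x32 input, 3x3x3 filter,
  // padding 1, stride 1): (32 + 1 + 1 - 3) / 1 + 1 = 32, so the spatial
  // extents are preserved.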
  /// Compute performance in GFLOP/s
  double gflops(double runtime_s) const {

    // Number of multiply-adds = NPQK * CRS
    int64_t fmas = output_size().product() * int64_t(filter_size.d() * filter_size.h() * filter_size.w() * filter_size.c());

    // Two flops per multiply-add
    return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
  }
};
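// Worked example for gflops(): a 1x32x32x32x32 input with a 32x3x3x3x32 filter
// at unit stride gives fmas = (1*32*32*32*32) * (3*3*3*32) = 905,969,664
// multiply-adds, i.e. roughly 1.81 GFLOP per pass.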
/////////////////////////////////////////////////////////////////////////////////////////////////

struct Result {
  double runtime_ms;
  double gflops;
  cutlass::Status status;
  cutlass::Status reference_check;
  cudaError_t error;

  Result():
    runtime_ms(0),
    gflops(0),
    status(cutlass::Status::kSuccess),
    reference_check(cutlass::Status::kInvalid),
    error(cudaSuccess) { }

  static std::ostream & print_header(std::ostream &out, Options const &options) {

    if (!options.tag.empty()) {
      out << "Name,";
    }

    out << "Layer,N,D,H,W,C,K,T,R,S,Stride_D,Stride_H,Stride_W,Runtime,GFLOPs";

    return out;
  }

  std::ostream & print(std::ostream &out, int idx, Options const &options) {

    if (!options.tag.empty()) {
      out << options.tag << ",";
    }

    out
      << "conv_" << idx << ","
      << options.input_size.n() << ","
      << options.input_size.d() << ","
      << options.input_size.h() << ","
      << options.input_size.w() << ","
      << options.input_size.c() << ","
      << options.filter_size.n() << ","
      << options.filter_size.d() << ","
      << options.filter_size.h() << ","
      << options.filter_size.w() << ","
      << options.conv_stride[0] << ","
      << options.conv_stride[1] << ","
      << options.conv_stride[2] << ","
      << runtime_ms << ","
      << gflops;

    return out;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Runs one benchmark
Result profile_convolution(Options const &options) {

  Result result;

  //
  // Allocate host-device tensors using the CUTLASS Utilities.
  //

  cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.input_size);
  cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_transformed_a(options.input_size);
  cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.filter_size);
  cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias>
      tensor_a_scale({1, options.input_size.c()});
  cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias>
      tensor_a_bias({1, options.input_size.c()});
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.output_size());
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(options.output_size());
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(options.output_size());

  //
  // Initialize tensors
  //

  // Fill tensor A on host with uniform-distribution random data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_a.host_view(),
      1,
      ElementInputA(3),
      ElementInputA(-4),
      0);

  // Fill scale vector for tensor A on host with uniform-distribution random
  // data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_a_scale.host_view(),
      1,
      ElementInputA(3),
      ElementInputA(-4),
      0);

  // Fill bias vector for tensor A on host with uniform-distribution random
  // data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_a_bias.host_view(),
      1,
      ElementInputA(3),
      ElementInputA(-4),
      0);

  // Fill tensor B on host with uniform-distribution random data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_b.host_view(),
      1,
      ElementInputB(7),
      ElementInputB(-8),
      0);

  // Fill tensor C on host with uniform-distribution random data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_c.host_view(),
      1,
      ElementOutput(7),
      ElementOutput(-8),
      0);

  // Fill tensor D for reference on host with zeros
  cutlass::reference::host::TensorFill(
      tensor_ref_d.host_view());

  // Copy data from host to GPU
  tensor_a.sync_device();
  tensor_a_scale.sync_device();
  tensor_a_bias.sync_device();
  tensor_b.sync_device();
  tensor_c.sync_device();
  tensor_d.sync_device();
  tensor_ref_d.sync_device();

  //
  // Define arguments for CUTLASS Convolution
  //

  cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;

  // Split K dimension into 1 partition
  int split_k_slices = 1;

  // Construct Conv3dProblemSize with user defined output size
  cutlass::conv::Conv3dProblemSize problem_size(
      options.input_size,
      options.filter_size,
      options.padding,
      options.conv_stride,
      options.dilation,
      options.output_size(),
      mode,
      split_k_slices
  );

  typename ImplicitGemmFusion::Arguments arguments{
    problem_size,
    tensor_a.device_ref(),
    tensor_b.device_ref(),
    tensor_a_scale.device_ref(),
    tensor_a_bias.device_ref(),
    tensor_c.device_ref(),
    tensor_d.device_ref(),
    {options.alpha, options.beta},
  };

  //
  // Initialize CUTLASS Convolution
  //

  ImplicitGemmFusion implicit_gemm_fusion_op;

  size_t workspace_size = implicit_gemm_fusion_op.get_workspace_size(arguments);

  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  result.status = implicit_gemm_fusion_op.can_implement(arguments);
  CUTLASS_CHECK(result.status);

  result.status = implicit_gemm_fusion_op.initialize(arguments, workspace.get());
  CUTLASS_CHECK(result.status);

  //
  // Launch initialized CUTLASS kernel
  //
  result.status = implicit_gemm_fusion_op();

  CUTLASS_CHECK(result.status);

  //
  // Optional reference check
  //

  if (options.reference_check) {
    std::cout << "Verification on device...\n";

    // Compute scale + bias + relu in host code
    for (int n = 0; n < options.input_size.n(); ++n) {
      for (int d = 0; d < options.input_size.d(); ++d) {
        for (int h = 0; h < options.input_size.h(); ++h) {
          for (int w = 0; w < options.input_size.w(); ++w) {
            for (int c = 0; c < options.input_size.c(); ++c) {
              tensor_transformed_a.at({n, d, h, w, c}) = std::max(
                  ElementOutput(0), ElementOutput(tensor_a.at({n, d, h, w, c}) *
                                                      tensor_a_scale.at({0, c}) +
                                                  tensor_a_bias.at({0, c})));
            }
          }
        }
      }
    }

    tensor_transformed_a.sync_device();

    // Compute with reference implementation
    cutlass::reference::device::Conv3dFprop<
      ElementInputA,
      LayoutInputA,
      ElementInputB,
      LayoutInputB,
      ElementOutput,
      LayoutOutput,
      ElementComputeEpilogue,
      ElementAccumulator,
      cutlass::NumericConverter<ElementOutput, ElementComputeEpilogue>
    >(
      problem_size,
      tensor_transformed_a.device_ref(),
      tensor_b.device_ref(),
      tensor_c.device_ref(),
      tensor_ref_d.device_ref(),
      options.alpha,
      options.beta
    );

    // Check if output from CUTLASS kernel and reference kernel are equal or not
    tensor_d.sync_host();
    tensor_ref_d.sync_host();

    bool passed = cutlass::reference::host::TensorEquals(
        tensor_d.host_view(),
        tensor_ref_d.host_view());

    if (!passed) {
      result.reference_check = cutlass::Status::kErrorInternal;
      std::cout << "ERROR - results miscompared.\n";
    }
    else {
      result.reference_check = cutlass::Status::kSuccess;
      std::cout << "Passed.\n";
    }
  }
  else {
    result.reference_check = cutlass::Status::kInvalid;
  }

  if (options.save_workspace) {

    std::stringstream ss;

    ss << "25_ampere_3d_fprop_mainloop_fusion"
       << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
       << "_"
       << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
       << ".dat";

    std::ofstream output_workspace(ss.str());

    output_workspace
      << "Input = \n" << tensor_a.host_view() << "\n\n"
      << "Filters = \n" << tensor_b.host_view() << "\n\n";

    if (options.reference_check) {
      output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n";
    }

    output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl;

    std::cout << "Results written to '" << ss.str() << "'." << std::endl;
  }

  //
  // Performance measurement
  //

  if (options.measure_performance) {

    cudaEvent_t events[2];

    for (auto & event : events) {
      result.error = cudaEventCreate(&event);
      if (result.error != cudaSuccess) {
        std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
        return result;
      }
    }

    // Record an event at the start of a series of convolution operations.
    result.error = cudaEventRecord(events[0]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
      return result;
    }

    // Launch a sequence of implicit GEMM operations on the device
    for (int iteration = 0; iteration < options.iterations; ++iteration) {
      result.status = implicit_gemm_fusion_op();
      CUTLASS_CHECK(result.status);
    }

    // Record an event when the convolutions have been launched.
    result.error = cudaEventRecord(events[1]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
      return result;
    }

    // Wait for work on the device to complete.
    result.error = cudaEventSynchronize(events[1]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
      return result;
    }

    // Measure elapsed runtime
    float runtime_ms = 0;
    result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
      return result;
    }

    // Print average runtime and GFLOPs.
    result.runtime_ms = double(runtime_ms) / double(options.iterations);
    result.gflops = options.gflops(result.runtime_ms / 1000.0);

    // Cleanup
    for (auto event : events) {
      (void)cudaEventDestroy(event);
    }
  }

  return result;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

int main(int argc, char const **args) {

  bool notSupported = false;

  // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
  //
  // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv3dFprop examples.
  if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) {
    std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
    notSupported = true;
  }

  cudaDeviceProp props;
  CUDA_CHECK(cudaGetDeviceProperties(&props, 0));

  if (!(props.major >= 8)) {
    std::cerr << "This test must run on SM80 or above.\n";
    notSupported = true;
  }

  if (notSupported) {
    return 0;
  }

  Options options;

  options.parse(argc, args);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  if (options.benchmark) {
    // Benchmark several layers

    int batch_sizes[] = {34, 18};

    struct Benchmark {
      int d, h, w, c, k, t, r, s, stride_d, stride_h, stride_w;
    } layers[] = {
      {56, 56, 56,   64,  256, 1, 1, 1, 1, 1, 1},
      {56, 56, 56,   64,   64, 1, 1, 1, 1, 1, 1},
      {56, 56, 56,   64,   64, 3, 3, 3, 1, 1, 1},
      {56, 56, 56,  256,   64, 1, 1, 1, 1, 1, 1},
      {56, 56, 56,  256,  512, 1, 1, 1, 2, 2, 2},
      {56, 56, 56,  256,  128, 1, 1, 1, 1, 1, 1},
      {56, 56, 56,  128,  128, 3, 3, 3, 2, 2, 2},
      {28, 28, 28,  128,  512, 1, 1, 1, 1, 1, 1},
      {28, 28, 28,  512,  128, 1, 1, 1, 1, 1, 1},
      {28, 28, 28,  128,  128, 3, 3, 3, 1, 1, 1},
      {28, 28, 28,  512, 1024, 1, 1, 1, 2, 2, 2},
      {28, 28, 28,  512,  256, 1, 1, 1, 1, 1, 1},
      {28, 28, 28,  256,  256, 3, 3, 3, 2, 2, 2},
      {14, 14, 14,  256, 1024, 1, 1, 1, 1, 1, 1},
      {14, 14, 14, 1024,  256, 1, 1, 1, 1, 1, 1},
      {14, 14, 14,  256,  256, 3, 3, 3, 1, 1, 1},
      {14, 14, 14, 1024, 2048, 1, 1, 1, 2, 2, 2},
      {14, 14, 14, 1024,  512, 1, 1, 1, 1, 1, 1},
      {14, 14, 14,  512,  512, 3, 3, 3, 2, 2, 2},
      { 7,  7,  7,  512, 2048, 1, 1, 1, 1, 1, 1},
      { 7,  7,  7, 2048,  512, 1, 1, 1, 1, 1, 1},
      { 7,  7,  7,  512,  512, 3, 3, 3, 1, 1, 1},
    };

    Result::print_header(std::cout, options) << std::endl;

    int idx = 1;

    for (auto const &layer : layers) {
      for (auto N : batch_sizes) {
        options.update({N, layer.d, layer.h, layer.w, layer.c},
                       {layer.k, layer.t, layer.r, layer.s, layer.c},
                       cutlass::make_Coord(layer.stride_d, layer.stride_h, layer.stride_w));

        Result result = profile_convolution(options);
        result.print(std::cout, idx, options) << std::endl;
      }

      ++idx;
    }
  }
  else {

    // Execute one problem size
    if (!options.valid()) {
      std::cerr << "Invalid problem." << std::endl;
      return -1;
    }

    Result result = profile_convolution(options);

    Result::print_header(std::cout, options) << std::endl;
    result.print(std::cout, 1, options) << std::endl;
  }

  return 0;
}

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -429,9 +429,13 @@ Result profile_convolution(Options const &options) {
      ElementInputB(-8),
      0);

-  // Fill tensor C on host with zeros
-  cutlass::reference::host::TensorFill(
-      tensor_c.host_view());
+  // Fill tensor C on host with uniform-distribution random data
+  cutlass::reference::host::TensorFillRandomUniform(
+      tensor_c.host_view(),
+      1,
+      ElementOutput(7),
+      ElementOutput(-8),
+      0);

  // Fill tensor D on host with zeros
  cutlass::reference::host::TensorFill(
@@ -575,7 +579,7 @@ Result profile_convolution(Options const &options) {

  std::stringstream ss;

-  ss << "25_ampere_fprop_mainloop_fusion_"
+  ss << "25_ampere_fprop_mainloop_fusion"
     << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
     << "_"
     << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
@@ -677,8 +681,8 @@ int main(int argc, char const **args) {
  cudaDeviceProp props;
  CUDA_CHECK(cudaGetDeviceProperties(&props, 0));

-  if (!(props.major == 8 && props.minor == 0)) {
-    std::cerr << "This test must run on SM80 A100.\n";
+  if (!(props.major >= 8)) {
+    std::cerr << "This test must run on SM80 or above.\n";
    notSupported = true;
  }
@@ -266,8 +266,8 @@ struct Options {
  /// Prints the usage statement.
  std::ostream & print_usage(std::ostream &out) const {

-    out << "26_ampere_fused_wgrad_batch_normalization example\n\n"
-      << "  This example fuses scale+bias+relu from batch norm into Ampere's\n"
+    out << "26_ampere_wgrad_mainloop_fusion example\n\n"
+      << "  This example fuses scale+bias+relu of the activation into Ampere's\n"
      << "  Tensor Core operators on F16 data types to compute\n"
      << "  backward convolution on tensors of layout NHWC.\n\n"
      << "Options:\n\n"

@@ -289,8 +289,8 @@ struct Options {
      << "  --tag=<string>       String to replicate across the first column in the results table\n";

    out << "\n\nExamples:\n\n"
-      << "$ ./examples/26_ampere_fused_fprop_batch_normalization/26_ampere_fused_wgrad_batch_normalization --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n"
-      << "$ ./examples/26_ampere_fused_fprop_batch_normalization/26_ampere_fused_wgrad_batch_normalization --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n";
+      << "$ ./examples/26_ampere_wgrad_mainloop_fusion/26_ampere_wgrad_mainloop_fusion --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n"
+      << "$ ./examples/26_ampere_wgrad_mainloop_fusion/26_ampere_wgrad_mainloop_fusion --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n";

    return out;
  }
@@ -427,9 +427,13 @@ Result profile_convolution(Options const &options) {
      ElementInputA(-4),
      0);

-  // Fill tensor C on host with zeros
-  cutlass::reference::host::TensorFill(
-      tensor_c.host_view());
+  // Fill tensor C on host with uniform-distribution random data
+  cutlass::reference::host::TensorFillRandomUniform(
+      tensor_c.host_view(),
+      1,
+      ElementOutput(7),
+      ElementOutput(-8),
+      0);

  // Fill tensor D on host with zeros
  cutlass::reference::host::TensorFill(
@@ -740,7 +740,7 @@ int main(int argc, char const **args) {
  cudaDeviceProp props;
  CUDA_CHECK(cudaGetDeviceProperties(&props, 0));

-  if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
+  if (!(props.major >= 8)) {
    std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
              << std::endl;
    notSupported = true;

@@ -703,7 +703,7 @@ int main(int argc, char const **args) {
  cudaDeviceProp props;
  CUDA_CHECK(cudaGetDeviceProperties(&props, 0));

-  if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
+  if (!(props.major >= 8)) {
    std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
              << std::endl;
    notSupported = true;

@@ -603,7 +603,7 @@ int main(int argc, char const **args) {
  cudaDeviceProp props;
  CUDA_CHECK(cudaGetDeviceProperties(&props, 0));

-  if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
+  if (!(props.major >= 8)) {
    std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
              << std::endl;
    notSupported = true;
@@ -47,14 +47,17 @@
#include "cutlass/util/host_tensor.h"

#include "cutlass/util/reference/host/gemm_complex.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/host/tensor_reduce.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/error_metrics.h"
#include "cutlass/util/tensor_view_io.h"

#include "cutlass/layout/matrix.h"
#include "cutlass/epilogue/thread/linear_combination.h"

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -85,18 +88,18 @@ struct Options {
  float alpha;
  float beta;
  bool verification_enabled;
-  double tolerance;
+  float tolerance;

  Options():
    help(false),
    problem_size({16, 24, 64}),
-    batch_count(1),    // As a temporary limitation to the test bench, batch count must be 1. The kernels support arbitrary batching.
+    batch_count(16),
    iterations(20),
    seed(2022),
    alpha(1),
-    beta(),
+    beta(0),
    verification_enabled(true),
-    tolerance(0.01)
+    tolerance(1e-5f)
  { }

  bool valid() {
@@ -116,6 +119,8 @@ struct Options {
    cmd.get_cmd_line_argument("n", problem_size.n());
    cmd.get_cmd_line_argument("k", problem_size.k());

+    cmd.get_cmd_line_argument("batch_count", batch_count);
+
    cmd.get_cmd_line_argument("alpha", alpha);
    cmd.get_cmd_line_argument("beta", beta);

@@ -135,6 +140,7 @@ struct Options {
      << "  --m=<int>            GEMM M dimension\n"
      << "  --n=<int>            GEMM N dimension\n"
      << "  --k=<int>            GEMM K dimension\n"
+      << "  --batch_count=<int>  Batch number\n"
      << "  --alpha=<f32>        Epilogue scalar alpha\n"
      << "  --beta=<f32>         Epilogue scalar beta\n\n"
      << "  --seed=<int>         Random number seed (1*)\n\n"
@@ -198,13 +204,22 @@ struct Testbed {
using ElementA = cutlass::half_t;
using ElementB = cutlass::half_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
using ElementCompute = float;
using ElementSoftmax = cutlass::half_t;
using ElementD = ElementC;
using ElementSoftmax = ElementC;

using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;

using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

using OperatorClass = cutlass::arch::OpClassTensorOp;
using ArchTag = cutlass::arch::Sm80;

static int const kStages = 3;

/// Linear scaling operator
using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination<
ElementC,
@@ -218,12 +233,21 @@ struct Testbed {
ElementB, LayoutB,
ElementC,
ElementCompute,
EpilogueFunctorOp
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueFunctorOp,
kStages
>;

using ElementNorm = typename GemmSoftmax::ElementNorm;
using ElementSum = typename GemmSoftmax::ElementSum;
using LayoutC = typename GemmSoftmax::LayoutC;
using LayoutN = typename GemmSoftmax::LayoutN;
using LayoutS = typename GemmSoftmax::LayoutS;
using MatrixCoord = typename LayoutC::TensorCoord;

//
// Data members
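Editor's note (not part of the diff): the hunk above threads the tile configuration through the `GemmSoftmax` template instead of hard-coding it inside the header. A minimal sketch of the resulting instantiation, assuming this example's `gemm_with_softmax.h` header and the template parameter order shown in the hunk:

```cpp
// Editor's sketch: instantiating the batched GEMM+Softmax operator with the
// tile sizes named above. Assumes the GemmSoftmax template declared in this
// example's gemm_with_softmax.h.
#include "gemm_with_softmax.h"

#include "cutlass/arch/arch.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"

using ElementA       = cutlass::half_t;
using ElementB       = cutlass::half_t;
using ElementC       = cutlass::half_t;
using ElementCompute = float;

// Plain alpha * AB + beta * C scaling in the epilogue.
using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination<
    ElementC,
    128 / cutlass::sizeof_bits<ElementC>::value,
    ElementCompute,
    ElementCompute
>;

using GemmSoftmax = cutlass::GemmSoftmax<
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementC,
    ElementCompute,
    cutlass::arch::OpClassTensorOp,          // OperatorClass
    cutlass::arch::Sm80,                     // ArchTag
    cutlass::gemm::GemmShape<128, 128, 32>,  // ThreadblockShape
    cutlass::gemm::GemmShape<64, 64, 32>,    // WarpShape
    cutlass::gemm::GemmShape<16, 8, 16>,     // InstructionShape
    EpilogueFunctorOp,
    3                                        // kStages
>;
```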
@@ -231,20 +255,42 @@ struct Testbed {

Options const &options;

cutlass::HostTensor<ElementA, LayoutA> tensor_A;
cutlass::HostTensor<ElementB, LayoutB> tensor_B;
cutlass::HostTensor<ElementC, LayoutC> tensor_C;
cutlass::HostTensor<ElementD, LayoutC> tensor_D;
cutlass::HostTensor<ElementNorm, LayoutC> tensor_N;
cutlass::HostTensor<ElementSum, LayoutC> tensor_S;
cutlass::HostTensor<ElementSoftmax, LayoutC> tensor_Softmax;

cutlass::HostTensor<ElementD, LayoutC> reference_D;
cutlass::HostTensor<ElementNorm, LayoutC> reference_N;
cutlass::HostTensor<ElementSoftmax, LayoutC> reference_Softmax;

cutlass::DeviceAllocation<ElementA> block_A;
cutlass::DeviceAllocation<ElementB> block_B;
cutlass::DeviceAllocation<ElementC> block_C;
cutlass::DeviceAllocation<ElementD> block_D;
cutlass::DeviceAllocation<ElementD> block_Ref;
cutlass::DeviceAllocation<ElementSoftmax> block_Softmax;
cutlass::DeviceAllocation<ElementNorm> block_Norm;
cutlass::DeviceAllocation<ElementSum> block_Sum;

int block_num = (options.problem_size.n() + GemmSoftmax::ThreadblockShape::kN - 1) / GemmSoftmax::ThreadblockShape::kN;

cutlass::gemm::GemmCoord problem = options.problem_size;

int64_t lda = LayoutA::packed({problem.m(), problem.k()}).stride(0);
int64_t ldb = LayoutB::packed({problem.k(), problem.n()}).stride(0);
int64_t ldc = LayoutC::packed({problem.m(), problem.n()}).stride(0);

// fixed rowmajor for norm and sum
int64_t ldn = problem.m();
int64_t lds = ldn;

int64_t total_elements_A_per_batch = problem.m() * problem.k();
int64_t total_elements_B_per_batch = problem.k() * problem.n();
int64_t total_elements_C_per_batch = problem.m() * problem.n();
int64_t total_elements_D_per_batch = problem.m() * problem.n();
int64_t total_elements_partial_norm_per_batch = block_num * problem.m();

int64_t total_elements_A = total_elements_A_per_batch * options.batch_count;
int64_t total_elements_B = total_elements_B_per_batch * options.batch_count;
int64_t total_elements_C = total_elements_C_per_batch * options.batch_count;
int64_t total_elements_D = total_elements_D_per_batch * options.batch_count;
int64_t total_elements_partial_norm = total_elements_partial_norm_per_batch * options.batch_count;

//
// Methods
//
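Editor's note (not part of the diff): for packed layouts, the leading dimensions and per-batch footprints above follow mechanically from the problem shape. A standalone sketch of the arithmetic, with the default 16x24x64 problem and batch count 16 assumed for illustration:

```cpp
// Editor's sketch: leading dimensions and batch footprints for packed
// row-major A, column-major B, row-major C/D, as used by the testbed above.
#include <cstdint>
#include <iostream>

int main() {
  int64_t m = 16, n = 24, k = 64, batch_count = 16;

  // Packed strides: a row-major M-by-K matrix has leading dimension K;
  // a column-major K-by-N matrix also has leading dimension K;
  // a row-major M-by-N matrix has leading dimension N.
  int64_t lda = k;
  int64_t ldb = k;
  int64_t ldc = n;

  // One batch's footprint in elements; batch b starts b * footprint past
  // the base pointer of the flat allocation.
  int64_t elements_A_per_batch = m * k;
  int64_t elements_D_per_batch = m * n;

  std::cout << "A allocation: " << elements_A_per_batch * batch_count << " elements\n"
            << "D allocation: " << elements_D_per_batch * batch_count << " elements\n"
            << "lda=" << lda << " ldb=" << ldb << " ldc=" << ldc << "\n";
}
```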
@@ -254,20 +300,7 @@ struct Testbed {
):
options(options_)
{

tensor_A.reset({options.problem_size.m(), options.problem_size.k()});
tensor_B.reset({options.problem_size.k(), options.problem_size.n()});

tensor_C.reset({options.problem_size.m(), options.problem_size.n()});
tensor_D.reset({options.problem_size.m(), options.problem_size.n()});

tensor_N.reset({block_num, options.problem_size.m()});
tensor_S.reset({block_num, options.problem_size.m()});
tensor_Softmax.reset({options.problem_size.m(), options.problem_size.n()});

reference_D.reset({options.problem_size.m(), options.problem_size.n()}, false);
reference_N.reset({options.problem_size.m(), 1}, false);
reference_Softmax.reset({options.problem_size.m(), options.problem_size.n()}, false);
}

/// Run
@@ -300,11 +333,6 @@ struct Testbed {
return disposition;
}

//
// Compute the reference
//
compute_reference();

//
// Verify
//
@@ -334,43 +362,38 @@ struct Testbed {
/// Random initialization
void initialize() {

cutlass::reference::host::TensorFillRandomUniform(
tensor_A.host_view(),
options.seed,
ElementD(5),
ElementD(-5),
0
);
block_A.reset(total_elements_A);
block_B.reset(total_elements_B);
block_C.reset(total_elements_C);
block_D.reset(total_elements_D);
block_Softmax.reset(total_elements_D);
block_Ref.reset(total_elements_D_per_batch);
block_Norm.reset(total_elements_partial_norm);
block_Sum.reset(total_elements_partial_norm);

cutlass::reference::host::TensorFillRandomUniform(
tensor_B.host_view(),
options.seed + 19,
ElementD(5),
ElementD(-5),
0
);
cutlass::reference::device::BlockFillRandomUniform(
block_A.get(), total_elements_A, options.seed, ElementA(5), ElementA(-5), 0);

cutlass::reference::host::TensorFill(
reference_D.host_view(),
ElementD()
);
cutlass::reference::device::BlockFillRandomUniform(
block_B.get(), total_elements_B, options.seed + 1, ElementB(5), ElementB(-5), 0);

cutlass::reference::device::BlockFillRandomUniform(
block_C.get(), total_elements_C, options.seed + 2, ElementC(5), ElementC(-5), 0);

cutlass::reference::device::BlockFillRandomUniform(
block_D.get(), total_elements_D, options.seed + 3, ElementD(5), ElementD(-5), 0);

cutlass::reference::device::BlockFillRandomUniform(
block_Ref.get(), total_elements_D_per_batch, options.seed + 3, ElementD(5), ElementD(-5), 0);

cutlass::reference::device::BlockFillRandomUniform(
block_Softmax.get(), total_elements_D, options.seed + 3, ElementSoftmax(5), ElementSoftmax(-5), 0);

cutlass::reference::host::TensorFill(
reference_N.host_view(),
ElementNorm()
);

cutlass::reference::host::TensorFill(
reference_Softmax.host_view(),
ElementSoftmax()
);

tensor_A.sync_device();
tensor_B.sync_device();
tensor_D.sync_device();
tensor_N.sync_device();
tensor_S.sync_device();
tensor_Softmax.sync_device();
}

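Editor's note (not part of the diff): the hunk above swaps the per-tensor host fills for device-side fills over flat batched allocations. A minimal standalone sketch of that pattern, using only the CUTLASS utilities named in the hunk:

```cpp
// Editor's sketch: allocate a flat device block covering all batches and
// fill it with uniform random values in one call.
#include "cutlass/half.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/reference/device/tensor_fill.h"

void fill_example() {
  using Element = cutlass::half_t;
  int64_t total_elements = 16 * 24 * 16;  // m * n * batch_count, for example

  cutlass::DeviceAllocation<Element> block(total_elements);

  // Arguments mirror the hunk above: pointer, element count, seed,
  // max value, min value, and the trailing bits argument.
  cutlass::reference::device::BlockFillRandomUniform(
      block.get(), total_elements, /*seed=*/2022, Element(5), Element(-5), 0);
}
```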
cutlass::Status execute_device_kernel() {
@@ -384,17 +407,24 @@ struct Testbed {
GemmSoftmax::Arguments args(
options.problem_size,
options.batch_count,
tensor_A.device_ref(),
tensor_B.device_ref(),
tensor_C.device_ref(),
tensor_D.device_ref(),
{block_A.get(), lda},
{block_B.get(), ldb},
{block_C.get(), ldc},
{block_D.get(), ldc},
{
ElementCompute(options.alpha),
ElementCompute(options.beta)
},
tensor_N.device_ref(),
tensor_S.device_ref(),
tensor_Softmax.device_ref()
{block_Norm.get(), ldn},
{block_Sum.get(), lds},
{block_Softmax.get(), ldc},
total_elements_A_per_batch,
total_elements_B_per_batch,
total_elements_C_per_batch,
total_elements_D_per_batch,
total_elements_partial_norm_per_batch,
total_elements_partial_norm_per_batch,
total_elements_D_per_batch
);

//
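Editor's note (not part of the diff): the batched `Arguments` above replace `HostTensor::device_ref()` with explicit `{pointer, leading-dimension}` pairs, which is all a `cutlass::TensorRef` needs; the trailing per-batch element counts serve as batch strides. A small hedged sketch of the equivalence:

```cpp
// Editor's sketch: a {pointer, ld} pair is just a TensorRef in disguise.
#include <cstdint>

#include "cutlass/half.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/tensor_ref.h"

void tensor_ref_example(cutlass::half_t *block_A, int64_t lda) {
  // Equivalent to the `{block_A.get(), lda}` initializer in the hunk above.
  cutlass::TensorRef<cutlass::half_t, cutlass::layout::RowMajor> ref_A(
      block_A, cutlass::layout::RowMajor(lda));

  // Batch b of a packed batched tensor begins batch_stride * b elements past
  // the base pointer; the kernel applies that offset from the per-batch
  // strides passed at the end of Arguments.
  (void)ref_A;
}
```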
@@ -415,68 +445,21 @@ struct Testbed {
return status;
}

/// Reference calculation
void compute_reference() {
template<typename Element>
bool verify_tensor(std::vector<Element> vector_Input, \
std::vector<Element> vector_Input_Ref) {

// Compute GEMM

cutlass::reference::host::GemmComplex(
options.problem_size,
options.alpha,
tensor_A.host_ref(),
cutlass::ComplexTransform::kNone,
tensor_B.host_ref(),
cutlass::ComplexTransform::kNone,
options.beta,
tensor_C.host_ref(),
reference_D.host_ref(),
double()
);

// Compute the norm
for (int m = 0; m < options.problem_size.m(); ++m) {
reference_N.at({m, 0}) = reference_D.at({m, 0});
for (int n = 1; n < options.problem_size.n(); ++n) {
reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementNorm(reference_D.at({m, n})));
}
}

// Compute softmax
for (int m = 0; m < options.problem_size.m(); ++m) {

float sum = float();

for (int n = 0; n < options.problem_size.n(); ++n) {
sum += std::exp( float(reference_D.at({m, n})) - float(reference_N.at({m, 0})) );
}

float inv_sum = float(1.0f / sum);

for (int n = 0; n < options.problem_size.n(); ++n) {

reference_Softmax.at({m, n}) = ElementSoftmax(
std::exp( float(reference_D.at({m, n})) - float(reference_N.at({m, 0})) ) * inv_sum
);
}
}
}

/// Emits all tensor values
void emit_results() {
std::cout << "D = \n" << tensor_D.host_view() << "\n\n";
std::cout << "N = \n" << tensor_N.host_view() << "\n\n";
std::cout << "Softmax = \n" << tensor_Softmax.host_view() << "\n\n";
std::cout << "Reference N = \n" << reference_N.host_view() << "\n\n";
std::cout << "Reference D = \n" << reference_D.host_view() << "\n\n";
std::cout << "Reference Softmax = \n" << reference_Softmax.host_view() << "\n\n";
}

bool verify_tensor_N(cutlass::HostTensor<ElementNorm, LayoutC> tensor_N, \
cutlass::HostTensor<ElementNorm, LayoutC> reference_N) {

for (int m = 0; m < options.problem_size.m(); ++m) {
float diff = (float)(tensor_N.at({0, m}) - reference_N.at({m, 0}));
if (fabs(diff) > options.tolerance) {
int64_t size = (vector_Input.size() < vector_Input_Ref.size()) ? vector_Input.size() : vector_Input_Ref.size();
float abs_tol = options.tolerance;
float rel_tol = options.tolerance;

for (int64_t i = 0; i < size; ++i) {
float diff = (float)(vector_Input.at(i) - vector_Input_Ref.at(i));
float abs_diff = fabs(diff);
float abs_ref = fabs((float)vector_Input_Ref.at(i));
float relative_diff = abs_ref > abs_tol ? abs_diff / abs_ref : 0;
if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > rel_tol && relative_diff > rel_tol)) {
printf("diff = %f, {%f, %f}.\n", abs_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i)));
return false;
}

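Editor's note (not part of the diff): the new `verify_tensor` above accepts a value when either its absolute error or its relative error is within tolerance, and rejects NaN/Inf outright. A standalone distillation of that predicate with two worked cases:

```cpp
// Editor's sketch: the acceptance predicate used by verify_tensor above.
// A value fails only if it is non-finite or exceeds BOTH the absolute and
// the relative tolerance.
#include <cmath>
#include <cstdio>

bool within_tolerance(float test, float ref, float tol) {
  float abs_diff = std::fabs(test - ref);
  float abs_ref = std::fabs(ref);
  // Fall back to absolute error when the reference is itself tiny.
  float rel_diff = (abs_ref > tol) ? (abs_diff / abs_ref) : 0.0f;
  if (std::isnan(abs_diff) || std::isinf(abs_diff)) {
    return false;
  }
  return (abs_diff <= tol) || (rel_diff <= tol);
}

int main() {
  // 1.0e-6 vs 0: absolute error 1e-6 <= 1e-5, accepted.
  std::printf("%d\n", within_tolerance(1.0e-6f, 0.0f, 1e-5f));
  // 100.001 vs 100: relative error 1e-5 <= 1e-5, accepted.
  std::printf("%d\n", within_tolerance(100.001f, 100.0f, 1e-5f));
}
```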
@@ -488,80 +471,112 @@ struct Testbed {
/// Verifies the reference matches
bool verify() {

tensor_D.sync_host();
tensor_N.sync_host();
tensor_Softmax.sync_host();
LayoutA layout_A(lda);
LayoutB layout_B(ldb);
LayoutC layout_C(ldc);
LayoutN Layout_N(ldn);
LayoutS Layout_S(lds);

double const kThreshold = options.tolerance;
MatrixCoord extent_A{problem.m(), problem.k()};
MatrixCoord extent_B{problem.k(), problem.n()};
MatrixCoord extent_C{problem.m(), problem.n()};

// Verification checks - set any of these to 'true' to override the verification checks.
bool verified_D = false;
bool verified_N = false;
bool verified_Softmax = false;
for (int batch_idx = 0; batch_idx < options.batch_count; batch_idx++) {

// Verify softmax output
if (!verified_D) {
cutlass::TensorView<ElementA, LayoutA> view_A(block_A.get() + total_elements_A_per_batch * batch_idx, layout_A, extent_A);
cutlass::TensorView<ElementB, LayoutB> view_B(block_B.get() + total_elements_B_per_batch * batch_idx, layout_B, extent_B);
cutlass::TensorView<ElementC, LayoutC> view_C(block_C.get() + total_elements_C_per_batch * batch_idx, layout_C, extent_C);
cutlass::TensorView<ElementC, LayoutC> view_Ref_device(block_Ref.get(), layout_C, extent_C);

double norm_diff = cutlass::reference::host::TensorNormDiff(
tensor_D.host_view(),
reference_D.host_view());
cutlass::reference::device::GemmComplex<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ElementCompute, ElementCompute
>(
problem,
options.alpha,
view_A,
cutlass::ComplexTransform::kNone,
view_B,
cutlass::ComplexTransform::kNone,
options.beta,
view_C,
view_Ref_device,
ElementCompute(0)
);

double norm_reference = cutlass::reference::host::TensorNorm(
reference_D.host_view());
// Copy reference results to host memory for verification
std::vector<ElementD> matrix_D_Ref(layout_C.capacity(extent_C));
cutlass::device_memory::copy_to_host(matrix_D_Ref.data(), block_Ref.get(), matrix_D_Ref.size());
cutlass::TensorView<ElementD, LayoutC> view_Ref(matrix_D_Ref.data(), layout_C, extent_C);

double rel_error = norm_diff / norm_reference;
std::vector<ElementSoftmax> matrix_Softmax_Ref(layout_C.capacity(extent_C));
cutlass::TensorView<ElementSoftmax, LayoutC> view_Softmax_Ref(matrix_Softmax_Ref.data(), layout_C, extent_C);

if (rel_error > kThreshold) {
std::cerr << "\n\nTensor D Relative error: " << rel_error << std::endl;
// Copy computed results to host memory
std::vector<ElementD> matrix_D(layout_C.capacity(extent_C));
cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + total_elements_D_per_batch * batch_idx, matrix_D.size());

std::vector<ElementD> matrix_Softmax(layout_C.capacity(extent_C));
cutlass::device_memory::copy_to_host(matrix_Softmax.data(), block_Softmax.get() + total_elements_D_per_batch * batch_idx, matrix_Softmax.size());

// Compute the norm
for (int m = 0; m < options.problem_size.m(); ++m) {
reference_N.at({m, 0}) = view_Ref.ref().at({m, 0});
for (int n = 1; n < options.problem_size.n(); ++n) {
reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementNorm(view_Ref.ref().at({m, n})));
}
}
else {
verified_D = true;

// Compute softmax
for (int m = 0; m < options.problem_size.m(); ++m) {

float sum = float();

for (int n = 0; n < options.problem_size.n(); ++n) {
sum += std::exp( float(view_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) );
}

float inv_sum = float(1.0f / sum);

for (int n = 0; n < options.problem_size.n(); ++n) {

view_Softmax_Ref.ref().at({m, n}) = ElementSoftmax(
std::exp( float(view_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ) * inv_sum
);
}
}
}

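Editor's note (not part of the diff): the reference loops above implement the max-subtracted ("safe") softmax. A standalone sketch of the same computation:

```cpp
// Editor's sketch: numerically stable row softmax, as computed by the
// reference loops above.
#include <algorithm>
#include <cmath>
#include <vector>

void softmax_row(std::vector<float> &row) {
  // Subtracting the row maximum keeps every exponent <= 0, so exp() cannot
  // overflow; the result is mathematically identical to naive softmax.
  float row_max = *std::max_element(row.begin(), row.end());
  float sum = 0.f;
  for (float &x : row) {
    x = std::exp(x - row_max);
    sum += x;
  }
  float inv_sum = 1.f / sum;
  for (float &x : row) {
    x *= inv_sum;
  }
}
```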
if (!verified_N) {
verified_N = verify_tensor_N(tensor_N, reference_N);
}
// Verification checks - set any of these to 'true' to override the verification checks.
bool verified_D = false;
bool verified_Softmax = false;

if (!verified_Softmax) {

double norm_diff = cutlass::reference::host::TensorNormDiff(
tensor_Softmax.host_view(),
reference_Softmax.host_view());

double norm_reference = cutlass::reference::host::TensorNorm(
reference_Softmax.host_view());

double rel_error = norm_diff / norm_reference;

if (rel_error > kThreshold) {
std::cerr << "\n\nSoftmax Relative error: " << rel_error << std::endl;
}
else {
verified_Softmax = true;
}
}

if (!verified_D || !verified_N || !verified_Softmax) {

std::cerr << "Verification check failed for tensor Softmax" << std::endl;

emit_results();

// Summarize which checks failed
// Verify softmax output
if (!verified_D) {
std::cerr << "Verification of D tensor failed\n";
}

if (!verified_N) {
std::cerr << "Verification of N tensor failed\n";
verified_D = verify_tensor<ElementC>(matrix_D, matrix_D_Ref);
}

if (!verified_Softmax) {
std::cerr << "Verification of Softmax tensor failed\n";
verified_Softmax = verify_tensor<ElementSoftmax>(matrix_Softmax, matrix_Softmax_Ref);
}

if (!verified_D || !verified_Softmax) {

std::cerr << "Verification check failed for tensor Softmax at batch " << batch_idx << "\n";

// Summarize which checks failed
if (!verified_D) {
std::cerr << "Verification of D tensor failed\n";
}

if (!verified_Softmax) {
std::cerr << "Verification of Softmax tensor failed\n";
}

return false;
}

return false;
}

return true;
@@ -637,14 +652,17 @@ struct Testbed {
int64_t flops = int64_t(options.problem_size.m()) * options.problem_size.n() * options.problem_size.k() * 2;
int64_t bytes = (sizeof(ElementD) * 2 + sizeof(ElementSoftmax)) * options.problem_size.m() * options.problem_size.n();

double gflops_per_second = double(flops) * kIterations / double(elapsed_ms / 1000.0f) / double(1.0e9);
double gbytes_per_second = double(bytes) * kIterations / double(elapsed_ms / 1000.0f) / double(1 << 30);
double gflops_per_second = double(flops) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1.0e9);
double gbytes_per_second = double(bytes) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1 << 30);

double elapsed_ms_per_iter = double(elapsed_ms) / kIterations;

std::cout << " Problem: "
<< options.problem_size.m() << "-by-" << options.problem_size.n() << "-by-" << options.problem_size.k()
<< ", batch size: " << options.batch_count
<< std::endl;

std::cout << " Runtime: " << elapsed_ms << " ms\n" << std::endl;
std::cout << " Runtime: " << elapsed_ms_per_iter << " ms\n" << std::endl;

std::cout << " GFLOPs: " << gflops_per_second << " GFLOPs" << std::endl;
std::cout << "Memory bandwidth: " << gbytes_per_second << " GiB/s" << std::endl;

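Editor's note (not part of the diff): the hunk above scales throughput by the batch count, since each of the `batch_count` GEMMs contributes `2*M*N*K` flops per iteration. A standalone restatement of the arithmetic, with a hypothetical timing value:

```cpp
// Editor's sketch: batched GEMM throughput math, as corrected above.
#include <cstdint>
#include <cstdio>

int main() {
  int64_t m = 16, n = 24, k = 64, batch_count = 16, iterations = 20;
  double elapsed_ms = 1.0;  // hypothetical total timing over all iterations

  int64_t flops_per_iter = 2 * m * n * k * batch_count;
  double seconds = elapsed_ms / 1000.0;
  double gflops = double(flops_per_iter) * iterations / seconds / 1.0e9;

  // Bytes per iteration: the hunk's (2 * sizeof(D) + sizeof(Softmax)) * M * N
  // estimate, scaled by batch; half_t elements are 2 bytes each.
  int64_t bytes_per_iter = (2 * 2 + 2) * m * n * batch_count;
  double gib_per_s =
      double(bytes_per_iter) * iterations / seconds / double(1 << 30);

  std::printf("GFLOP/s: %f, GiB/s: %f\n", gflops, gib_per_s);
}
```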
@@ -29,7 +29,8 @@
*
**************************************************************************************************/
/*! \file
\brief GEMM kernel to support the 'epilogue visitor' model for fusion.
\brief GEMM kernel to support the epilogue visitor model
for customized softmax partial reduction epilogue fusion.

This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once
its usage has been stabilized. For now, it is included in this example to demonstrate
@@ -78,6 +79,7 @@ public:

using ElementC = typename EpilogueVisitor::ElementOutput;
using LayoutC = typename Epilogue::Layout;
using TensorRefC = TensorRef<ElementC, LayoutC>;

static ComplexTransform const kTransformA = Mma::kTransformA;
static ComplexTransform const kTransformB = Mma::kTransformB;
@@ -89,6 +91,9 @@ public:
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;

using ElementNorm = typename EpilogueVisitor::ElementNorm;
using ElementSum = typename EpilogueVisitor::ElementSum;

static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
@@ -121,6 +126,11 @@ public:

TensorRefA ref_A;
TensorRefB ref_B;
TensorRefC ref_C;
TensorRefC ref_D;

ElementNorm *ptr_Max;
ElementSum *ptr_Sum;

int64_t batch_stride_A;
int64_t batch_stride_B;
@@ -144,6 +154,10 @@ public:
int batch_count_,
TensorRefA ref_A_,
TensorRefB ref_B_,
TensorRefC ref_C_,
TensorRefC ref_D_,
ElementNorm *ptr_Max_,
ElementSum *ptr_Sum_,
int64_t batch_stride_A_,
int64_t batch_stride_B_,
typename EpilogueVisitor::Arguments epilogue_visitor_
@@ -153,6 +167,10 @@ public:
batch_count(batch_count_),
ref_A(ref_A_),
ref_B(ref_B_),
ref_C(ref_C_),
ref_D(ref_D_),
ptr_Max(ptr_Max_),
ptr_Sum(ptr_Sum_),
batch_stride_A(batch_stride_A_),
batch_stride_B(batch_stride_B_),
epilogue_visitor(epilogue_visitor_)
@@ -174,6 +192,8 @@ public:

typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
typename EpilogueVisitor::OutputTileIterator::Params params_C;
typename EpilogueVisitor::OutputTileIterator::Params params_D;

GemmUniversalMode mode;
int batch_count;
@@ -181,6 +201,11 @@ public:

void * ptr_A;
void * ptr_B;
ElementC * ptr_C;
ElementC * ptr_D;

ElementNorm * ptr_Max;
ElementSum * ptr_Sum;

int64_t batch_stride_A;
int64_t batch_stride_B;
@@ -196,11 +221,17 @@ public:
swizzle_log_tile(0),
params_A(0),
params_B(0),
params_C(0),
params_D(0),
batch_count(0),
gemm_k_size(0),
mode(cutlass::gemm::GemmUniversalMode::kGemm),
ptr_A(nullptr),
ptr_B(nullptr),
ptr_C(nullptr),
ptr_D(nullptr),
ptr_Max(nullptr),
ptr_Sum(nullptr),
batch_stride_A(0),
batch_stride_B(0)
{ }
@@ -213,11 +244,17 @@ public:
swizzle_log_tile(0),
params_A(args.ref_A.layout()),
params_B(args.ref_B.layout()),
params_C(args.ref_C.layout()),
params_D(args.ref_D.layout()),
mode(args.mode),
batch_count(args.batch_count),
gemm_k_size(args.problem_size.k()),
ptr_A(args.ref_A.data()),
ptr_B(args.ref_B.data()),
ptr_C(args.ref_C.data()),
ptr_D(args.ref_D.data()),
ptr_Max(args.ptr_Max),
ptr_Sum(args.ptr_Sum),
batch_stride_A(args.batch_stride_A),
batch_stride_B(args.batch_stride_B),
epilogue_visitor(args.epilogue_visitor)
@@ -467,7 +504,14 @@ public:
thread_idx,
warp_idx,
lane_idx,
threadblock_offset);
params.params_C,
params.params_D,
params.ptr_C,
params.ptr_D,
params.ptr_Max,
params.ptr_Sum,
threadblock_offset,
blockIdx.y *params.problem_size.m() );

if (params.mode == GemmUniversalMode::kGemm) {
// Indicate which position in a serial reduction the output operator is currently updating

@@ -49,10 +49,12 @@
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h"
#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h"
#include "cutlass/reduction/kernel/reduce_softmax_final.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

#include "epilogue_with_visitor.h"
#include "gemm_with_epilogue_visitor.h"

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -209,6 +211,9 @@ private:
int idx_m = block_m + thread_m;
int idx_n = block_n + thread_n;

int batch_offset_norm = block_batch * params.args.batch_stride_N;
int batch_offset_sum = block_batch * params.args.batch_stride_S;

// Kill off thread if it is outside the row boundary
if (params.args.extent.row() <= idx_m) {
return;
@@ -251,8 +256,8 @@ private:
params.args.batch_stride_Soft * block_batch +
params.args.ref_Soft.layout()({idx_m, idx_n}));

ElementSum inv_sum = (params.args.ref_S.data())[block_m];
ElementNorm norm = (params.args.ref_N.data())[block_m];
ElementSum inv_sum = (params.args.ref_S.data())[block_m + batch_offset_sum];
ElementNorm norm = (params.args.ref_N.data())[block_m + batch_offset_norm];

//
// Loop
@@ -281,556 +286,6 @@ private:
}
};

template <
typename ElementNorm_,
typename ElementSum_,
typename ElementSoftmaxCompute_,
typename ThreadblockShape_
>
class ApplyFinalReduction {
public:

using ElementNorm = ElementNorm_;
using ElementSum = ElementSum_;
using ElementSoftmaxCompute = ElementSoftmaxCompute_;
using ThreadblockShape = ThreadblockShape_;

using Layout = cutlass::layout::RowMajor;

using TensorRefN = TensorRef<ElementNorm, Layout>;
using TensorRefSum = TensorRef<ElementSum, Layout>;

//
// Arguments
//

struct Arguments {

MatrixCoord extent; ///< Extent of D and Softmax matrices
int batch_count; ///< Batch count
TensorRefN ref_N; ///< Norm tensor (input / output)
TensorRefSum ref_Sum; ///< Sum tensor (input / output)
int64_t batch_stride_N; ///< Batch stride for N tensor
int64_t batch_stride_Sum; ///< Batch stride for softmax tensor

//
// Methods
//
Arguments():
batch_count(1),
batch_stride_N(0),
batch_stride_Sum(0)
{ }

Arguments(
MatrixCoord extent_, ///< Extent of D and Softmax matrices
int batch_count_, ///< Batch count
TensorRefN ref_N_, ///< Output parameter for N
TensorRefSum ref_Sum_ , ///< Sum
int64_t batch_stride_N_ = 0,
int64_t batch_stride_Sum_ = 0
):
extent(extent_),
batch_count(batch_count_),
ref_N(ref_N_),
ref_Sum(ref_Sum_),
batch_stride_N(batch_stride_N_),
batch_stride_Sum(batch_stride_Sum_)
{

}
};

struct SharedStorage {

};

//
// Params struct
//

struct Params {
Arguments args;

//
// Methods
//
Params() { }

Params(Arguments const &args_): args(args_) { }
};

private:

public:

CUTLASS_DEVICE
ApplyFinalReduction() { }

CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {

apply(params, shared_storage);
}

private:

/// Partial reduction
CUTLASS_DEVICE
void apply(Params const &params, SharedStorage &shared_storage) {

int threadblock_num = (params.args.extent.column() + ThreadblockShape::kN - 1) / ThreadblockShape::kN;

int block_batch = blockIdx.z;

int block_n = blockIdx.x * blockDim.x;

int thread_n = threadIdx.x;

int idx_n = block_n + thread_n;

if (idx_n >= params.args.extent.row()) {
return;
}

using ConvertSumOutput = cutlass::NumericConverter<ElementSum, ElementSoftmaxCompute>;
using ConvertNormOutput = cutlass::NumericConverter<ElementNorm, ElementSoftmaxCompute>;

using ConvertSum = cutlass::NumericConverter<ElementSoftmaxCompute, ElementSum>;
using ConvertNorm = cutlass::NumericConverter<ElementSoftmaxCompute, ElementNorm>;

ConvertSum convert_sum;
ConvertNorm convert_norm;

ConvertSumOutput convert_sum_output;
ConvertNormOutput convert_norm_output;

ElementNorm *access_n = params.args.ref_N.data() + params.args.batch_stride_N * block_batch + idx_n;
ElementSum *access_s = params.args.ref_Sum.data() + params.args.batch_stride_Sum * block_batch + idx_n;

ElementNorm *access_n_bak = access_n;
ElementSum *access_s_bak = access_s;

uint32_t float_max_bits = 0xff7fffff;
float min_float = reinterpret_cast<float const &>(float_max_bits);

ElementSoftmaxCompute max_val = ElementSoftmaxCompute(min_float);
ElementSoftmaxCompute sum_val = ElementSoftmaxCompute(0);
ElementNorm fetch_n;
ElementSum fetch_s;

CUTLASS_PRAGMA_UNROLL
for (int idx_m = 0; idx_m < threadblock_num; idx_m++) {
arch::global_load<ElementNorm, sizeof(ElementNorm)>(fetch_n, access_n, true);
max_val = fast_max(max_val, convert_norm(fetch_n));
access_n += params.args.extent.row();
}

access_n = access_n_bak;

CUTLASS_PRAGMA_UNROLL
for (int idx_m = 0; idx_m < threadblock_num; idx_m++) {
arch::global_load<ElementNorm, sizeof(ElementNorm)>(fetch_n, access_n, true);
arch::global_load<ElementSum, sizeof(ElementSum)>(fetch_s, access_s, true);
sum_val += convert_sum(fetch_s) * fast_exp(convert_norm(fetch_n) - max_val);
access_n += params.args.extent.row();
access_s += params.args.extent.row();
}

ElementSoftmaxCompute inv_sum = cutlass::constants::one<ElementSoftmaxCompute>() / sum_val;

access_n = access_n_bak;
access_s = access_s_bak;

access_n[0] = convert_norm_output(max_val);
access_s[0] = convert_sum_output(inv_sum);
}
};

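Editor's note (not part of the diff): the removed `ApplyFinalReduction` kernel above (its role is taken over by `cutlass::reduction::kernel::ApplySoftmaxFinalReduction`) combines per-threadblock partial softmax statistics. Each threadblock column leaves a partial row maximum m_i and a partial sum-of-exponentials s_i taken relative to its own m_i; they merge exactly as in this host-side sketch:

```cpp
// Editor's sketch: merging partial softmax statistics, as the reduction
// kernel above does per row.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// Returns {global_max, inverse_of_global_sum} for one row.
std::pair<float, float> combine_partials(std::vector<float> const &m,
                                         std::vector<float> const &s) {
  float global_max = m[0];
  for (float mi : m) {
    global_max = std::max(global_max, mi);
  }

  // Rescale each partial sum into the global frame: s_i was accumulated
  // against exp(x - m_i), so multiply by exp(m_i - global_max).
  float global_sum = 0.f;
  for (std::size_t i = 0; i < m.size(); ++i) {
    global_sum += s[i] * std::exp(m[i] - global_max);
  }
  return {global_max, 1.f / global_sum};
}
```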
/////////////////////////////////////////////////////////////////////////////////////////////////

template <
typename ThreadblockShape_,
int ThreadCount,
typename OutputTileIterator_,
typename ElementAccumulator_,
typename ElementNorm_,
typename ElementSum_,
typename ElementSoftmaxCompute_,
typename ElementwiseFunctor_
>
class EpilogueVisitorBiasMax {
public:

using ThreadblockShape = ThreadblockShape_;
static int const kThreadCount = ThreadCount;

using OutputTileIterator = OutputTileIterator_;
using ElementwiseFunctor = ElementwiseFunctor_;

static int const kIterations = OutputTileIterator::kIterations;
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

using ElementOutput = typename OutputTileIterator::Element;
using LayoutOutput = cutlass::layout::RowMajor;
using ElementAccumulator = ElementAccumulator_;

using ElementNorm = ElementNorm_;
using ElementSum = ElementSum_;
using ElementSoftmaxCompute = ElementSoftmaxCompute_;

using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
using SoftmaxFragment = Array<ElementSoftmaxCompute, kElementsPerAccess>;
using OutputVector = Array<ElementOutput, kElementsPerAccess>;
using TensorRefD = TensorRef<ElementOutput, LayoutOutput>;

/// Argument structure
struct Arguments {

typename ElementwiseFunctor::Params elementwise;
TensorRefD ref_C;
TensorRefD ref_D;
ElementNorm *ptr_Max;
ElementSum *ptr_Sum;
int64_t batch_stride_C;
int64_t batch_stride_D;
int64_t batch_stride_Max;
int64_t batch_stride_Sum;

//
// Methods
//
Arguments():
ptr_Max(nullptr),
ptr_Sum(nullptr),
batch_stride_C(0),
batch_stride_D(0),
batch_stride_Max(0),
batch_stride_Sum(0)
{

}

Arguments(
typename ElementwiseFunctor::Params elementwise_,
TensorRefD ref_C_,
TensorRefD ref_D_,
ElementNorm *ptr_Max_,
ElementSum *ptr_Sum_,
int64_t batch_stride_C_,
int64_t batch_stride_D_,
int64_t batch_stride_Max_,
int64_t batch_stride_Sum_
):
elementwise(elementwise_),
ref_C(ref_C_),
ref_D(ref_D_),
ptr_Max(ptr_Max_),
ptr_Sum(ptr_Sum_),
batch_stride_C(batch_stride_C_),
batch_stride_D(batch_stride_D_),
batch_stride_Max(batch_stride_Max_),
batch_stride_Sum(batch_stride_Sum_)
{

}
};

struct Params {

typename ElementwiseFunctor::Params elementwise;
typename OutputTileIterator::Params params_C;
typename OutputTileIterator::Params params_D;
typename OutputTileIterator::Element *ptr_C;
typename OutputTileIterator::Element *ptr_D;
ElementNorm *ptr_Max;
ElementSum *ptr_Sum;
int64_t batch_stride_C;
int64_t batch_stride_D;
int64_t batch_stride_Max;
int64_t batch_stride_Sum;

//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
ptr_D(nullptr),
ptr_Max(nullptr),
ptr_Sum(nullptr)
{

}

CUTLASS_HOST_DEVICE
Params(Arguments const &args):
elementwise(args.elementwise),
params_C(args.ref_C.layout()),
params_D(args.ref_D.layout()),
ptr_C(args.ref_C.data()),
ptr_D(args.ref_D.data()),
ptr_Max(args.ptr_Max),
ptr_Sum(args.ptr_Sum),
batch_stride_C(args.batch_stride_C),
batch_stride_D(args.batch_stride_D),
batch_stride_Max(args.batch_stride_Max),
batch_stride_Sum(args.batch_stride_Sum)
{

}
};

/// Shared storage
struct SharedStorage {

};

private:

Params const & params_;
SharedStorage & shared_storage_;
MatrixCoord extent_;
ElementwiseFunctor elementwise_;

OutputTileIterator iterator_C_;
OutputTileIterator iterator_D_;
typename OutputTileIterator::Fragment fragment_C_;
typename OutputTileIterator::Fragment fragment_D_;

ElementAccumulator alpha_;
ElementAccumulator beta_;

ElementSoftmaxCompute accum_max_;
int threadblock_row_;

public:

CUTLASS_DEVICE
EpilogueVisitorBiasMax(
Params const &params, ///< Parameters routed to the epilogue
SharedStorage &shared_storage, ///< Shared storage needed by the functors here
MatrixCoord const &problem_size, ///< Problem size of the output
int thread_idx, ///< Thread index within the threadblock
int warp_idx, ///< Warp index within the threadblock
int lane_idx, ///< Lane index within the warp
MatrixCoord const &threadblock_offset = MatrixCoord(0, 0)
):
params_(params),
shared_storage_(shared_storage),
extent_(problem_size),
elementwise_(params.elementwise),
iterator_C_(params.params_C, params.ptr_C, problem_size, thread_idx, threadblock_offset),
iterator_D_(params.params_D, params.ptr_D, problem_size, thread_idx, threadblock_offset),
threadblock_row_(threadblock_offset.row())
{
alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha);
beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);

if (beta_ == ElementAccumulator()) {
iterator_C_.clear_mask();
}
}

/// Helper to indicate split-K behavior
CUTLASS_DEVICE
void set_k_partition(
int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
int split_k_slices) { ///< Total number of split-K slices

}

/// Called to set the batch index
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
}

/// Called at the start of the epilogue just before iterating over accumulator slices
CUTLASS_DEVICE
void begin_epilogue() {

}

/// Called at the start of one step before starting accumulator exchange
CUTLASS_DEVICE
void begin_step(int step_idx) {
fragment_D_.clear();
fragment_C_.clear();

if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
iterator_C_.load(fragment_C_);
++iterator_C_;
}

}

/// Called at the start of a row
CUTLASS_DEVICE
void begin_row(int row_idx) {

}

/// Called after accumulators have been exchanged for each accumulator vector
CUTLASS_DEVICE
void visit(
int row_idx,
int column_idx,
int frag_idx,
AccumulatorFragment const &accum) {

using Mul = cutlass::multiplies<SoftmaxFragment>;
using Minus = cutlass::minus<SoftmaxFragment>;
using Exp = cutlass::fast_exp_op<SoftmaxFragment>;

Minus minus;
Exp exponential;

SoftmaxFragment result;

using ConvertSumOutput = cutlass::NumericConverter<ElementSoftmaxCompute, ElementSum>;
using ConvertNormOutput = cutlass::NumericConverter<ElementSoftmaxCompute, ElementNorm>;

ConvertSumOutput convert_sum_output;
ConvertNormOutput convert_norm_output;

NumericArrayConverter<ElementSoftmaxCompute, ElementOutput, kElementsPerAccess> source_converter;
OutputVector &source_vector = reinterpret_cast<OutputVector *>(&fragment_C_)[frag_idx];

if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
result = source_converter(elementwise_(accum));
}else{
result = source_converter(elementwise_(accum, source_vector));
}

MatrixCoord thread_offset =
iterator_D_.thread_start() +
OutputTileIterator::ThreadMap::iteration_offset(frag_idx);

int thread_in_row = OutputTileIterator::ThreadMap::Detail::RowArrangement::Detail::kShapeWidth;
int half_thread_in_row = (thread_in_row >> 1);

bool column_guard = (thread_offset.column() < extent_.column());

// Compute the maximum within one row
if (!column_idx) {
// This is the first fragment in a new row
if (column_guard) {
accum_max_ = maximum_accumulator_(result);
}
}
else {
// This is an additional fragment in the same row
if (column_guard) {
accum_max_ = maximum_accumulator_(result, accum_max_);
}
}

CUTLASS_PRAGMA_UNROLL
for (int i = half_thread_in_row; i > 0; i >>= 1) {
ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, accum_max_, i);
accum_max_ = fast_max(accum_max_, tmp);
}

SoftmaxFragment sum_frag = exponential(minus(result, accum_max_));

ElementSoftmaxCompute reduction_sum = sum_accumulator_(sum_frag);

CUTLASS_PRAGMA_UNROLL
for (int i = half_thread_in_row; i > 0; i >>= 1) {
ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, reduction_sum, i);
reduction_sum += tmp;
}

bool is_write_thread = (thread_offset.row() < extent_.row() && (threadIdx.x % thread_in_row) == 0);
ElementNorm *curr_ptr_max = params_.ptr_Max + thread_offset.row() + blockIdx.y * extent_.row();
ElementSum *curr_ptr_sum = params_.ptr_Sum + thread_offset.row() + blockIdx.y * extent_.row();

arch::global_store<ElementNorm, sizeof(ElementNorm)>(
convert_norm_output(accum_max_),
(void *)curr_ptr_max,
is_write_thread);

arch::global_store<ElementSum, sizeof(ElementSum)>(
convert_sum_output(reduction_sum),
(void *)curr_ptr_sum,
is_write_thread);

clear_accum_max_();

// Convert to the output
NumericArrayConverter<ElementOutput, ElementSoftmaxCompute, kElementsPerAccess> output_converter;
OutputVector &output = reinterpret_cast<OutputVector *>(&fragment_D_)[frag_idx];
output = output_converter(result);
}

/// Called at the start of a row
CUTLASS_DEVICE
void end_row(int row_idx) {

}

/// Called after all accumulator elements have been visited
CUTLASS_DEVICE
void end_step(int step_idx) {

iterator_D_.store(fragment_D_);
++iterator_D_;
}

/// Called after all steps have been completed
CUTLASS_DEVICE
void end_epilogue() {

}

private:

CUTLASS_DEVICE
void clear_accum_max_() {

uint32_t float_max_bits = 0xff7fffff; // -FLT_MAX
float min_float = reinterpret_cast<float const &>(float_max_bits);
accum_max_ = ElementSoftmaxCompute(min_float);
}

CUTLASS_DEVICE
ElementSoftmaxCompute sum_accumulator_(SoftmaxFragment const &accum) {
ElementSoftmaxCompute sum_ = ElementSoftmaxCompute(0);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < SoftmaxFragment::kElements; ++i) {
sum_ += ElementSoftmaxCompute(accum[i]);
}

return sum_;
}

CUTLASS_DEVICE
ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum) {
ElementSoftmaxCompute max_ = accum[0];

CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < SoftmaxFragment::kElements; ++i) {
max_ = fast_max(max_, ElementSoftmaxCompute(accum[i]));
}

return max_;
}

CUTLASS_DEVICE
ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum, ElementSoftmaxCompute max_) {

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < SoftmaxFragment::kElements; ++i) {
max_ = fast_max(max_, ElementSoftmaxCompute(accum[i]));
}

return max_;
}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

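Editor's note (not part of the diff): the visitor above reduces each row across a warp with `__shfl_xor_sync` butterflies. That pattern, isolated as a tiny runnable CUDA kernel:

```cpp
// Editor's sketch: XOR-butterfly warp reduction for a maximum, as used by
// the visitor's row reductions above.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void warp_row_max(float const *in, float *out) {
  float v = in[threadIdx.x];
  // Each round pairs lanes whose IDs differ in one bit; after log2(32) = 5
  // rounds every lane holds the maximum of the full warp.
  for (int i = 16; i > 0; i >>= 1) {
    v = fmaxf(v, __shfl_xor_sync(0xFFFFFFFF, v, i));
  }
  if (threadIdx.x == 0) {
    *out = v;
  }
}

int main() {
  float h_in[32], *d_in, *d_out, h_out;
  for (int i = 0; i < 32; ++i) h_in[i] = float(i % 7);
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  warp_row_max<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("max = %f\n", h_out);  // expected: 6
  return 0;
}
```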
@@ -846,10 +301,18 @@ template <
typename LayoutB_,
typename ElementC_,
typename ElementCompute_,
typename OperatorClass_,
typename ArchTag_,
typename ThreadblockShape_,
typename WarpShape_,
typename InstructionShape_,
typename EpilogueFunctorOp_,
int kStages_,
int AlignmentA_ = 128 / cutlass::sizeof_bits<ElementA_>::value,
int AlignmentB_ = 128 / cutlass::sizeof_bits<ElementB_>::value,
int AlignmentSoftmax_ = 128 / cutlass::sizeof_bits<ElementC_>::value,
typename ElementNorm_ = float,
typename ElementSum_ = float,
int Alignment = 128 / cutlass::sizeof_bits<ElementA_>::value,
typename ElementSoftmax_ = ElementC_
>
class GemmSoftmax {
@@ -872,8 +335,6 @@ public:
using LayoutA = LayoutA_;
using LayoutB = LayoutB_;

static int const kAlignment = Alignment;

using EpilogueFunctorOp = EpilogueFunctorOp_;
using ElementNorm = ElementNorm_;

@@ -890,13 +351,17 @@ public:
using TensorRefSum = TensorRef<ElementSum, LayoutS>;
using TensorRefSoft = TensorRef<ElementSoft, LayoutSoft>;

using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;

using OperatorClass = cutlass::arch::OpClassTensorOp;
using ArchTag = cutlass::arch::Sm80;
static int const kStages = 3;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;

static int const kStages = kStages_;
static int const AlignmentA = AlignmentA_;
static int const AlignmentB = AlignmentB_;
static int const AlignmentSoftmax = AlignmentSoftmax_;

using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;

@@ -906,10 +371,10 @@ public:
using DefaultGemmKernel = typename cutlass::gemm::kernel::DefaultGemm<
ElementA,
LayoutA,
kAlignment,
AlignmentA,
ElementB,
LayoutB,
kAlignment,
AlignmentB,
ElementC,
LayoutC,
ElementCompute,
@@ -930,7 +395,7 @@ public:
///////////////////////////////////////////////////////////////////////////////////////////////

// Epilogue visitor
using EpilogueVisitor = kernel::EpilogueVisitorBiasMax<
using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorSoftmax<
ThreadblockShape,
DefaultGemmKernel::kThreadCount,
typename DefaultGemmKernel::Epilogue::OutputTileIterator,
@@ -961,13 +426,13 @@ public:
ElementSum,
ElementSoft,
ElementSoftmaxCompute,
kAlignment,
AlignmentSoftmax,
MatrixShape<
1, 1024
>
>;

using ApplyFinalReductionKernel = kernel::ApplyFinalReduction<
using ApplyFinalReductionKernel = cutlass::reduction::kernel::ApplySoftmaxFinalReduction<
ElementNorm,
ElementSum,
ElementSoftmaxCompute,
@@ -983,6 +448,7 @@ public:
typename SoftmaxApplyKernel::Arguments softmax;
typename ApplyFinalReductionKernel::Arguments reduction;
cutlass::gemm::GemmCoord extend;

//
// Methods
//
@@ -1013,14 +479,14 @@ public:
batch_count_,
ref_A_,
ref_B_,
ref_C_,
ref_D_,
ref_N_.data(),
ref_S_.data(),
batch_stride_A_,
batch_stride_B_,
typename EpilogueVisitor::Arguments(
linear_scaling,
ref_C_,
ref_D_,
ref_N_.data(),
ref_S_.data(),
batch_stride_C_,
batch_stride_D_,
batch_stride_Max_,
@@ -1028,10 +494,9 @@ public:
)
),
reduction(
MatrixCoord(problem_size.m(), problem_size.n()),
batch_count_,
ref_N_,
ref_S_,
problem_size,
ref_N_.data(),
ref_S_.data(),
batch_stride_Max_,
batch_stride_Sum_
),
@@ -1127,28 +592,24 @@ public:
// Launch the ApplyFinalReductionKernel
//

int threadblock_num_in_column = (params_.extend.column() + ThreadblockShape::kN - 1) / ThreadblockShape::kN;
int thread_per_block = 128;
int block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block;
if (block_per_row < 4) {
thread_per_block = 32;
block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block;
}

if (threadblock_num_in_column > 1) {
int thread_per_block = 128;
int block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block;
if (block_per_row < 4) {
thread_per_block = 32;
block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block;
}
dim3 final_reduction_grid(block_per_row, 1, params_.softmax.args.batch_count);
dim3 final_reduction_block(thread_per_block);

dim3 final_reduction_grid(block_per_row);
dim3 final_reduction_block(thread_per_block);
Kernel<ApplyFinalReductionKernel><<<
final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionKernel::SharedStorage), stream
>>>(params_.reduction);

Kernel<ApplyFinalReductionKernel><<<
final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionKernel::SharedStorage), stream
>>>(params_.reduction);
result = cudaGetLastError();

result = cudaGetLastError();

if (result != cudaSuccess) {
return cutlass::Status::kErrorInternal;
}
if (result != cudaSuccess) {
return cutlass::Status::kErrorInternal;
}

//

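Editor's note (not part of the diff): the launch-shape heuristic in the hunk above, isolated. One thread reduces one row's chain of partial results, the batch index rides in `grid.z`, and the launch is skipped entirely when a single threadblock column already produced final values (`threadblock_num_in_column == 1`):

```cpp
// Editor's sketch: picking the final-reduction launch shape.
#include <cuda_runtime.h>

dim3 final_reduction_shape(int rows, int batch_count, int &thread_per_block) {
  thread_per_block = 128;
  int block_per_row = (rows + thread_per_block - 1) / thread_per_block;
  if (block_per_row < 4) {
    // Small problems: shrink the block so the grid keeps a few blocks alive.
    thread_per_block = 32;
    block_per_row = (rows + thread_per_block - 1) / thread_per_block;
  }
  return dim3(block_per_row, 1, batch_count);
}
```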
@@ -40,18 +40,17 @@
// for (int j = 0; j < options.index_size; ++j) {
// int b_c_d_col = tensor_indices.at({j, 0});
//
// for (int k = 0; k < problem_size.k(); ++k) {
// for (int k = 0; k < options.index_size; ++k) {
// tensor_d_ref.at({i, b_c_d_col}) +=
// alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col});
// }
// }
// }
//
// Note that the index vector contains unique random integers with max to be N - 1
//
// The gather/scatter operation works best when we can still keep the biggest
// alignment. For example, when the matrix is row major, we select rows. When
// the matrix is column major, we selct columns.
// the matrix is column major, we select columns.
//
// Not all the combination of gather and scatter are legal. For example, if A is
// row major and C/D is column major, we cannot gather A and scatter C/D at the
@@ -257,7 +256,7 @@ using Gemm = cutlass::gemm::device::GemmUniversal<ElementInputA,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
false, /*GatherA*/
false, /*GatherA*/
true, /*GatherB*/
true /*ScatterD*/
>;
@@ -353,7 +352,7 @@ int run(Options &options) {
tensor_b.layout().stride(),
tensor_c.layout().stride(),
tensor_d_scattered.layout().stride(),
nullptr, // <- pointer to index vector to gather A on device
nullptr, // <- pointer to index vector to gather A on device
tensor_indices.device_data(), // <- pointer to index vector to gather B on device
tensor_indices.device_data()}; // <- pointer to index vector to scatter D on device

@@ -392,7 +391,7 @@ int run(Options &options) {
tensor_d_ref.at({i, b_c_d_col}) +=
alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col});
}


tensor_d_ref.at({i, b_c_d_col}) += (beta * tensor_c.at({i, b_c_d_col}));
}
}
@@ -515,7 +514,7 @@ int main(int argc, const char ** argv) {
cudaDeviceProp props;
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));

if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
if (!(props.major >= 8)) {
std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
<< std::endl;
notSupported = true;

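Editor's note (not part of the diff): the commented reference above describes the fused gather-B / scatter-D GEMM this example exercises. A standalone host-side sketch of that reference, with layouts matching the example (row-major A/C/D, column-major B):

```cpp
// Editor's sketch: host reference for gather-B / scatter-D GEMM. Only the
// columns named by `indices` are computed; each result lands in its indexed
// column of D.
#include <vector>

void gather_scatter_reference(
    int M, int N, int K,
    std::vector<float> const &A,      // M x K, row-major
    std::vector<float> const &B,      // K x N, column-major
    std::vector<float> const &C,      // M x N, row-major
    std::vector<float> &D,            // M x N, row-major
    std::vector<int> const &indices,  // unique values in [0, N)
    float alpha, float beta) {
  for (int i = 0; i < M; ++i) {
    for (int idx : indices) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        // Column idx of column-major B starts at offset idx * K.
        acc += A[i * K + k] * B[idx * K + k];
      }
      D[i * N + idx] = alpha * acc + beta * C[i * N + idx];
    }
  }
}
```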
36
examples/37_gemm_layernorm_gemm_fusion/CMakeLists.txt
Normal file
@@ -0,0 +1,36 @@

# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.



cutlass_example_add_executable(
37_gemm_layernorm_gemm_fusion
gemm_layernorm.cu
)

937
examples/37_gemm_layernorm_gemm_fusion/gemm_layernorm.cu
Normal file
937
examples/37_gemm_layernorm_gemm_fusion/gemm_layernorm.cu
Normal file
@ -0,0 +1,937 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief CUTLASS Layernorm Example.

    This workload provides a layer normalization example using a one-pass, square-sum-based
    variance calculation. Specifically, the reduction producing the local mean and the local
    mean of squares is fused into the epilogue of the 1st GEMM. After a lightweight full-reduction
    kernel, the mean / variance values are readily available for the element-wise operations
    that are fused into the 2nd GEMM.

    As noted in https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data,
    the square-sum-based one-pass implementation may raise numerical stability concerns.
    That said, although this fully fused layernorm example almost perfectly hides the memory cost of
    accessing the intermediate matrix for the layernorm computation, the numerical issue may hinder
    its use in real-world scenarios. If that is a concern, a user may turn to the stand-alone CUTLASS
    layernorm example in tools/util/include/cutlass/util/device_layernorm.h

    Examples:

      # Run a CUTLASS layernorm example with the default setup,
      # using the vocabulary of transformer models as an example
      # (column-major output matrix, hidden dimension = 768, valid word number = 4096, intermediate_scale = 4):
      $ ./examples/37_gemm_layernorm_gemm_fusion/37_gemm_layernorm_gemm_fusion

      # Run the fusion example with hidden dimension = 512:
      $ ./examples/37_gemm_layernorm_gemm_fusion/37_gemm_layernorm_gemm_fusion --hidden_dim=512

*/
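// A note on the math above: for N values x_1, ..., x_N, the one-pass ("square sum") estimator
// computes, in a single sweep,
//
//   mean     = (1/N) * sum_i x_i
//   variance = (1/N) * sum_i x_i^2 - mean^2
//
// while the shifted variant (see kIsShiftedVariance below) first subtracts a constant K that is
// ideally close to the mean:
//
//   variance = (1/N) * sum_i (x_i - K)^2 - ((1/N) * sum_i (x_i - K))^2
//
// The two forms are algebraically identical, but the shifted form loses less precision when the
// values are far from zero.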
#include <cmath>
#include <iostream>
#include <vector>
#include <limits>

#include "cutlass/cutlass.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/device/gemm_complex.h"
#include "cutlass/epilogue/thread/scale_type.h"

#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/gemm_complex.h"
#include "cutlass/util/reference/host/tensor_reduce.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/error_metrics.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/fast_math.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

#include "gemm_with_layernorm.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

enum class Disposition {
  kPassed,
  kIncorrect,
  kNotVerified
};

/////////////////////////////////////////////////////////////////////////////////////////////////

// Command line options parsing
template<typename LayoutOutput_>
struct Options {

  using LayoutOutput = LayoutOutput_;

  static bool const kIsColumnMajorOutput = cutlass::platform::is_same<LayoutOutput, cutlass::layout::ColumnMajor>::value;

  bool help;
  cutlass::gemm::GemmCoord problem_size0;
  cutlass::gemm::GemmCoord problem_size1;
  int hidden_dim;
  int valid_word_num;
  int intermediate_scale;
  int iterations;
  unsigned seed;
  float alpha;
  float beta;
  bool verification_enabled;
  double tolerance;

  // problem_size0 and problem_size1 are fully assigned in parse()
  Options():
    help(false),
    iterations(20),
    seed(2022),
    hidden_dim(768),
    valid_word_num(4096),
    intermediate_scale(4),
    alpha(1),
    beta(0),
    verification_enabled(true),
    tolerance(0.01)
  { }

  bool valid() {
    return true;
  }

  // Parses the command line
  void parse(int argc, char const **args) {
    cutlass::CommandLine cmd(argc, args);

    if (cmd.check_cmd_line_flag("help")) {
      help = true;
    }

    cmd.get_cmd_line_argument("hidden_dim", hidden_dim, 768);
    cmd.get_cmd_line_argument("valid_word_num", valid_word_num, 4096);
    cmd.get_cmd_line_argument("iterations", iterations);
    cmd.get_cmd_line_argument("verify", verification_enabled);
    cmd.get_cmd_line_argument("seed", seed);
    cmd.get_cmd_line_argument("tolerance", tolerance);

    if (kIsColumnMajorOutput) {
      // column-major output setup
      problem_size0.m() = hidden_dim;
      problem_size0.n() = valid_word_num;
      problem_size0.k() = hidden_dim;

      problem_size1.m() = hidden_dim * intermediate_scale;
      problem_size1.n() = valid_word_num;
      problem_size1.k() = hidden_dim;
    } else {
      // row-major output setup
      problem_size0.m() = valid_word_num;
      problem_size0.n() = hidden_dim;
      problem_size0.k() = hidden_dim;

      problem_size1.m() = valid_word_num;
      problem_size1.n() = hidden_dim * intermediate_scale;
      problem_size1.k() = hidden_dim;
    }
  }

  /// Prints the usage statement.
  std::ostream & print_usage(std::ostream &out) const {

    out << "37_gemm_layernorm_gemm_fusion example\n\n"
      << "  This example uses the CUTLASS Library to compute GEMM + Layernorm for arbitrary problem sizes.\n\n"
      << "Options:\n\n"
      << "  --help                      If specified, displays this usage statement.\n\n"
      << "  --hidden_dim=<int>          Hidden dimension\n"
      << "  --valid_word_num=<int>      Valid word number\n"
      << "  --seed=<int>                Random number seed (default: 2022)\n\n"
      << "  --iterations=<int>          Number of profiling iterations to perform (0 to disable profiling).\n\n"
      << "  --verify=<bool>             If true, performs reference calculation.\n\n"
      << "  --tolerance=<float>         Error tolerance\n"
      ;

    out << "\n\nExamples:\n\n"
      << "$ ./examples/37_gemm_layernorm_gemm_fusion/37_gemm_layernorm_gemm_fusion \\\n"
      << "     --hidden_dim=768 --valid_word_num=1024 \n\n";

    return out;
  }

  /// Returns true if the environment and Toolkit support this
  bool supported(bool verbose = true) const {

    // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available
    // in CUDA 11.0.
    //
    // CUTLASS must be compiled with the CUDA 11.0 Toolkit or later to run these examples.
    if (!(__CUDACC_VER_MAJOR__ >= 11)) {
      if (verbose) {
        std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
      }
      return false;
    }

    cudaDeviceProp props;

    cudaError_t error = cudaGetDeviceProperties(&props, 0);
    if (error != cudaSuccess) {
      if (verbose) {
        std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
      }
      return false;
    }

    if (!((props.major * 10 + props.minor) >= 80)) {
      if (verbose) {
        std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
                  << std::endl;
      }
      return false;
    }

    //
    // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently,
    // all pointers, strides, and tensor extents must be divisible by 8 elements.
    //
    int const kAlignment = 8;

    if ((problem_size0.m() % kAlignment) ||
        (problem_size0.n() % kAlignment) ||
        (problem_size0.k() % kAlignment)) {
      if (verbose) {
        std::cerr << "Misaligned input in 1st GEMM." << std::endl;
      }
      // misaligned tensors for Gemm1
      return false;
    }

    if ((problem_size1.m() % kAlignment) ||
        (problem_size1.n() % kAlignment) ||
        (problem_size1.k() % kAlignment)) {
      if (verbose) {
        std::cerr << "Misaligned input in 2nd GEMM." << std::endl;
      }
      // misaligned tensors for Gemm2
      return false;
    }

    return true;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

template<
  typename LayoutOutput_>
struct Testbed {

  //
  // Type definitions
  //

  // User-defined data types
  using ElementInputA0 = cutlass::half_t;
  using ElementInputB0 = cutlass::half_t;
  using ElementOutput = cutlass::half_t;
  using ElementCompute = cutlass::half_t;

  using LayoutInputA0 = cutlass::layout::RowMajor;
  using LayoutInputB0 = cutlass::layout::ColumnMajor;
  using LayoutOutput = LayoutOutput_;

  static bool const kIsColumnMajorOutput = cutlass::platform::is_same<LayoutOutput, cutlass::layout::ColumnMajor>::value;
  // turn off shifted-K by default
  static bool const kIsShiftedVariance = false;

  /// Linear scaling operator
  using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    128 / cutlass::sizeof_bits<ElementOutput>::value,
    ElementCompute,
    ElementCompute
  >;

  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  static int const kStages0 = 3;
  static int const kStages1 = 4;

  using GemmLayernorm = cutlass::GemmLayernorm<
    ElementInputA0,
    LayoutInputA0,
    ElementInputB0,
    LayoutInputB0,
    ElementOutput,
    LayoutOutput,
    ElementCompute,
    EpilogueFunctorOp,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    kStages0,
    kStages1,
    kIsShiftedVariance
  >;

  using ElementInputA1 = typename GemmLayernorm::ElementInputA1;
  using ElementOutputC1 = typename GemmLayernorm::ElementOutputC1;
  using ElementInputScaleBias = typename GemmLayernorm::ElementInputScaleBias;
  using ElementLayernormCompute = typename GemmLayernorm::ElementLayernormCompute;

  using LayoutInputA1 = typename GemmLayernorm::LayoutInputA1;
  using LayoutOutputC0 = typename GemmLayernorm::LayoutOutputC0;
  using LayoutOutputC1 = typename GemmLayernorm::LayoutOutputC1;
  using LayoutInputScaleBias = typename GemmLayernorm::LayoutInputScaleBias;

  //
  // Data members
  //

  Options<LayoutOutput> const &options;

  cutlass::HostTensor<ElementInputA0, LayoutInputA0> tensor_A0;
  cutlass::HostTensor<ElementInputB0, LayoutInputB0> tensor_B0;
  cutlass::HostTensor<ElementOutput, LayoutOutputC0> tensor_C0;
  cutlass::HostTensor<ElementInputA1, LayoutInputA1> tensor_A1;
  cutlass::HostTensor<ElementOutputC1, LayoutOutputC1> tensor_C1;

  cutlass::HostTensor<ElementOutput, LayoutOutputC0> reference_C0;
  cutlass::HostTensor<ElementOutputC1, LayoutOutputC1> reference_C1;

  cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias> tensor_Variance;
  cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias> tensor_Mean;
  cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias> tensor_Beta;
  cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias> tensor_Gamma;

  cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias> reference_Mean;
  cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias> reference_Variance;

  // shifted-K tensor to better ensure numerical stability.
  // According to https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance,
  // the closer the shifted K is to the actual mean, the better numerical stability we'll observe
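  // (i.e., the estimator becomes E[(x - K)^2] - (E[x - K])^2, which equals the variance for any
  // constant K; see the sketch near the top of this file)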
  cutlass::HostTensor<ElementOutput, LayoutOutputC0> tensor_Shifted_K;

  //
  // Methods
  //

  Testbed(
    Options<LayoutOutput> const &options_
  ):
    options(options_)
  {

    tensor_A0.reset({options.problem_size0.m(), options.problem_size0.k()});
    tensor_B0.reset({options.problem_size0.k(), options.problem_size0.n()});

    tensor_C0.reset({options.problem_size0.m(), options.problem_size0.n()});

    tensor_A1.reset({options.problem_size1.m(), options.problem_size1.k()});
    tensor_C1.reset({options.problem_size1.m(), options.problem_size1.n()});

    reference_C0.reset({options.problem_size0.m(), options.problem_size0.n()});
    reference_C1.reset({options.problem_size1.m(), options.problem_size1.n()});

    int leading_dim_0 = kIsColumnMajorOutput ? options.problem_size0.n() : options.problem_size0.m();
    int leading_dim_1 = kIsColumnMajorOutput ? options.problem_size0.m() : options.problem_size0.n();

    int block_num = (leading_dim_1 + GemmLayernorm::ThreadblockShape::kM - 1) / GemmLayernorm::ThreadblockShape::kM;

    tensor_Variance.reset({block_num, leading_dim_0});
    tensor_Mean.reset({block_num, leading_dim_0});
    tensor_Shifted_K.reset({1, leading_dim_0});

    tensor_Beta.reset({1, leading_dim_1});
    tensor_Gamma.reset({1, leading_dim_1});

    reference_Mean.reset({1, leading_dim_0}, false);
    reference_Variance.reset({1, leading_dim_0}, false);
  }
  /// Run
  Disposition run() {

    Disposition disposition = Disposition::kNotVerified;

    //
    // Initialize the workspace
    //

    initialize();

    //
    // Launch device kernel
    //
    cutlass::Status status = cutlass::Status::kSuccess;

    status = execute_device_kernel();

    if (status != cutlass::Status::kSuccess) {
      std::cerr << "Device execution failed." << std::endl;
      return disposition;
    }

    cudaError_t result = cudaDeviceSynchronize();
    if (result != cudaSuccess) {
      std::cerr << "Device synchronize failed with error "
                << cudaGetErrorString(result) << std::endl;
      return disposition;
    }

    //
    // Compute the reference
    //
    compute_reference();

    //
    // Verify
    //

    if (options.verification_enabled) {

      bool passed = verify();

      if (passed) {
        disposition = Disposition::kPassed;
      }
      else {
        disposition = Disposition::kIncorrect;
      }
    }

    //
    // Profiling
    //
    if (options.iterations) {
      profile();
    }

    return disposition;
  }

  /// Random initialization
  void initialize() {

    cutlass::reference::host::TensorFillRandomUniform(
      tensor_A0.host_view(),
      options.seed,
      ElementInputA0(5),
      ElementInputA0(-5),
      0
    );

    cutlass::reference::host::TensorFillRandomUniform(
      tensor_B0.host_view(),
      options.seed + 1,
      ElementInputB0(5),
      ElementInputB0(-5),
      0
    );

    cutlass::reference::host::TensorFillRandomUniform(
      tensor_A1.host_view(),
      options.seed + 2,
      ElementInputA1(5),
      ElementInputA1(-5),
      0
    );

    cutlass::reference::host::TensorFillRandomUniform(
      tensor_Beta.host_view(),
      options.seed + 3,
      ElementInputScaleBias(5),
      ElementInputScaleBias(-5),
      0
    );

    cutlass::reference::host::TensorFillRandomUniform(
      tensor_Gamma.host_view(),
      options.seed + 4,
      ElementInputScaleBias(5),
      ElementInputScaleBias(-5),
      0
    );

    cutlass::reference::host::TensorFillRandomUniform(
      tensor_Shifted_K.host_view(),
      options.seed + 5,
      ElementOutput(5),
      ElementOutput(-6),
      0
    );

    tensor_A0.sync_device();
    tensor_B0.sync_device();
    tensor_A1.sync_device();
    tensor_Beta.sync_device();
    tensor_Gamma.sync_device();
    // copy the host-filled shifted-K values as well, since their device pointer is
    // passed to the kernel arguments below
    tensor_Shifted_K.sync_device();
  }
  cutlass::Status execute_device_kernel() {

    cutlass::Status status = cutlass::Status::kSuccess;

    //
    // Setup arguments
    //

    typename GemmLayernorm::Arguments args(
      options.problem_size0,
      options.problem_size1,
      tensor_A0.device_ref().data(),
      tensor_B0.device_ref().data(),
      tensor_C0.device_ref().data(),
      tensor_C0.device_ref().data(),
      tensor_A1.device_ref().data(),
      tensor_C1.device_ref().data(),
      tensor_A0.device_ref().stride(0),
      tensor_B0.device_ref().stride(0),
      tensor_C0.device_ref().stride(0),
      tensor_C0.device_ref().stride(0),
      tensor_A1.device_ref().stride(0),
      tensor_C1.device_ref().stride(0),
      {
        ElementCompute(options.alpha),
        ElementCompute(options.beta)
      },
      tensor_Variance.device_ref(),
      tensor_Mean.device_ref(),
      tensor_Gamma.device_ref(),
      tensor_Beta.device_ref(),
      tensor_Shifted_K.device_ref().data()
    );

    //
    // Launch
    //

    GemmLayernorm gemm_layernorm;

    // Initialize
    status = gemm_layernorm.initialize(args);
    if (status != cutlass::Status::kSuccess) {
      return status;
    }

    // Run
    status = gemm_layernorm();

    return status;
  }
  /// Reference calculation
  void compute_reference() {

    cutlass::reference::device::Gemm<
      ElementInputA0,
      LayoutInputA0,
      ElementInputB0,
      LayoutInputB0,
      ElementOutput,
      LayoutOutputC0,
      ElementCompute,
      ElementCompute
    > gemm_device0;

    cutlass::reference::device::Gemm<
      ElementInputA1,
      LayoutInputA1,
      ElementOutput,
      LayoutOutputC0,
      ElementOutputC1,
      LayoutOutputC1,
      ElementCompute,
      ElementCompute
    > gemm_device1;

    // Compute 1st GEMM
    gemm_device0(
      options.problem_size0,
      ElementCompute(options.alpha),
      tensor_A0.device_ref(),
      tensor_B0.device_ref(),
      ElementCompute(options.beta),
      tensor_C0.device_ref(),
      reference_C0.device_ref()
    );

    reference_C0.sync_host();

    tensor_Mean.sync_host();
    tensor_Variance.sync_host();
    tensor_Gamma.sync_host();
    tensor_Beta.sync_host();
    tensor_Shifted_K.sync_host();

    // Compute the sum and square sum for verification purposes
    if (kIsColumnMajorOutput) {
      for (int n = 0; n < options.problem_size0.n(); ++n) {

        ElementLayernormCompute sum = ElementLayernormCompute(0);
        ElementLayernormCompute square_sum = ElementLayernormCompute(0);
        for (int m = 0; m < options.problem_size0.m(); ++m) {
          sum += ElementLayernormCompute(reference_C0.at({m, n}));
          square_sum += ElementLayernormCompute(reference_C0.at({m, n})) * ElementLayernormCompute(reference_C0.at({m, n}));
        }

        ElementLayernormCompute mean = sum / ElementLayernormCompute(options.problem_size0.m());
        ElementLayernormCompute square_mean = square_sum / ElementLayernormCompute(options.problem_size0.m());
        // Note: this quantity is rsqrt(variance + epsilon), not the raw variance
        ElementLayernormCompute variance = cutlass::constants::one<ElementLayernormCompute>() / cutlass::fast_sqrt(square_mean - mean * mean + ElementLayernormCompute(1e-6));

        mean = -mean * variance;

        reference_Mean.at({0, n}) = ElementInputScaleBias(mean);
        reference_Variance.at({0, n}) = ElementInputScaleBias(variance);
      }
    } else {
      for (int m = 0; m < options.problem_size0.m(); ++m) {

        ElementLayernormCompute sum = ElementLayernormCompute(0);
        ElementLayernormCompute square_sum = ElementLayernormCompute(0);
        for (int n = 0; n < options.problem_size0.n(); ++n) {
          sum += ElementLayernormCompute(reference_C0.at({m, n}));
          square_sum += ElementLayernormCompute(reference_C0.at({m, n})) * ElementLayernormCompute(reference_C0.at({m, n}));
        }

        ElementLayernormCompute mean = sum / ElementLayernormCompute(options.problem_size0.n());
        ElementLayernormCompute square_mean = square_sum / ElementLayernormCompute(options.problem_size0.n());
        ElementLayernormCompute variance = cutlass::constants::one<ElementLayernormCompute>() / cutlass::fast_sqrt(square_mean - mean * mean + ElementLayernormCompute(1e-6));

        mean = -mean * variance;

        reference_Mean.at({0, m}) = ElementInputScaleBias(mean);
        reference_Variance.at({0, m}) = ElementInputScaleBias(variance);
      }
    }

    // Element-wise layernorm transform for OutputC0 using a two-pass reference computation
    if (kIsColumnMajorOutput) {
      for (int n = 0; n < options.problem_size0.n(); ++n) {

        ElementLayernormCompute sum = ElementLayernormCompute(0);
        for (int m = 0; m < options.problem_size0.m(); ++m) {
          sum += ElementLayernormCompute(reference_C0.at({m, n}));
        }

        ElementInputScaleBias mean = ElementInputScaleBias(sum / ElementLayernormCompute(options.problem_size0.m()));
        sum = ElementLayernormCompute(0);
        for (int m = 0; m < options.problem_size0.m(); ++m) {
          sum += ElementLayernormCompute(reference_C0.at({m, n}) - ElementLayernormCompute(mean)) * ElementLayernormCompute(reference_C0.at({m, n}) - ElementLayernormCompute(mean));
        }

        ElementLayernormCompute square_mean = sum / ElementLayernormCompute(options.problem_size0.m());
        ElementInputScaleBias variance = ElementInputScaleBias(cutlass::constants::one<ElementLayernormCompute>()
                                           / cutlass::fast_sqrt(square_mean + ElementLayernormCompute(1e-6)));

        for (int m = 0; m < options.problem_size0.m(); ++m) {
          reference_C0.at({m, n}) =
            ElementOutput( ( (ElementInputScaleBias(reference_C0.at({m, n})) - mean) * variance )
              * tensor_Gamma.at({0, m}) + tensor_Beta.at({0, m}));
        }
      }
    } else {

      for (int m = 0; m < options.problem_size0.m(); ++m) {

        float sum = float(0);
        for (int n = 0; n < options.problem_size0.n(); ++n) {
          sum += float(reference_C0.at({m, n}));
        }

        float mean = sum / float(options.problem_size0.n());
        sum = float(0);
        for (int n = 0; n < options.problem_size0.n(); ++n) {
          sum += float(reference_C0.at({m, n}) - mean) * float(reference_C0.at({m, n}) - mean);
        }

        float square_mean = sum / float(options.problem_size0.n());
        float variance = cutlass::constants::one<float>() / cutlass::fast_sqrt(square_mean + 1e-6f);

        for (int n = 0; n < options.problem_size0.n(); ++n) {
          reference_C0.at({m, n}) =
            ElementOutput( ( (float(reference_C0.at({m, n})) - mean) * variance )
              * float(tensor_Gamma.at({0, n})) + float(tensor_Beta.at({0, n})));
        }
      }
    }

    // Copy the transformed host data back to the device after the element-wise transform
    reference_C0.sync_device();

    // Compute 2nd GEMM
    gemm_device1(
      options.problem_size1,
      ElementCompute(options.alpha),
      kIsColumnMajorOutput ? tensor_A1.device_ref() : reference_C0.device_ref(),
      kIsColumnMajorOutput ? reference_C0.device_ref() : tensor_A1.device_ref(),
      ElementCompute(options.beta),
      reference_C1.device_ref(),
      reference_C1.device_ref()
    );
  }
  /// Emits all tensor values
  void emit_results() {
    std::cout << "tensor_C1 = \n" << tensor_C1.host_view() << "\n\n";
    std::cout << "Reference C1 = \n" << reference_C1.host_view() << "\n\n";
    std::cout << "Mean = \n" << tensor_Mean.host_view() << "\n\n";
    std::cout << "rsqrt(Variance) = \n" << tensor_Variance.host_view() << "\n\n";
    std::cout << "Reference Mean = \n" << reference_Mean.host_view() << "\n\n";
    std::cout << "Reference rsqrt(Variance) = \n" << reference_Variance.host_view() << "\n\n";
  }

  template<typename Element, typename Layout>
  bool verify_tensor(cutlass::HostTensor<Element, Layout> tensor,
                     cutlass::HostTensor<Element, Layout> reference,
                     int leading_dim0, int leading_dim1, bool is_print = false) {
    float const kThreshold = float(options.tolerance);
    float const kAbsThreshold = 0.5f;
    float const kRelativeThreshold = 0.1f;
    // Adds a constant bias to avoid dividing by '0'
    float const kBias = 1e-5f;
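    // An element is counted as an error only if it exceeds both the absolute and the relative
    // threshold; the tensor as a whole passes if the fraction of such elements stays below the
    // user-provided tolerance (see err_rate below).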
    int counter = 0;
    for (int m = 0; m < leading_dim0; m++) {
      for (int n = 0; n < leading_dim1; ++n) {
        float diff = (float)(tensor.at({m, n}) - reference.at({m, n}));
        float rel_diff = fabs(diff) / fabs(reference.at({m, n}) + kBias);
        if (fabs(diff) > kAbsThreshold && rel_diff > kRelativeThreshold) {
          counter++;
        }
      }
    }

    float err_rate = float(counter) / (float(leading_dim0) * float(leading_dim1));
    return (err_rate < kThreshold);
  }

  /// Verifies the reference matches
  bool verify() {

    tensor_Variance.sync_host();
    tensor_Mean.sync_host();
    tensor_C1.sync_host();
    reference_C1.sync_host();

    // Verification checks - set any of these to 'true' to override the verification checks.
    bool verified_C1 = false;
    bool verified_Mean = false;
    bool verified_Variance = false;

    // Verify layernorm output
    if (!verified_C1) {
      verified_C1 = verify_tensor<ElementOutputC1, LayoutOutputC1>(tensor_C1, reference_C1, options.problem_size1.m(), options.problem_size1.n());
    }

    if (!verified_Variance) {
      verified_Variance = verify_tensor<ElementInputScaleBias, LayoutInputScaleBias>(tensor_Variance, reference_Variance, 1, options.problem_size0.n());
    }

    if (!verified_Mean) {
      verified_Mean = verify_tensor<ElementInputScaleBias, LayoutInputScaleBias>(tensor_Mean, reference_Mean, 1, options.problem_size0.n());
    }

    if (!verified_C1 || !verified_Mean || !verified_Variance) {

      // emit_results();

      std::cerr << "Verification check failed for tensor Layernorm" << std::endl;

      // Summarize which checks failed
      if (!verified_C1) {
        std::cerr << "Verification of O tensor failed\n";
      }

      if (!verified_Mean) {
        std::cerr << "Verification of Mean tensor failed\n";
      }

      if (!verified_Variance) {
        std::cerr << "Verification of Variance tensor failed\n";
      }

      return false;
    }

    return true;
  }

  /// Profiles
  bool profile() {

    //
    // Profile
    //

    cutlass::Status status = cutlass::Status::kSuccess;
    cudaError_t result;
    cudaEvent_t events[2];
    int const kIterations = options.iterations;

    for (cudaEvent_t &evt : events) {
      result = cudaEventCreate(&evt);
      if (result != cudaSuccess) {
        std::cerr << "cudaEventCreate failed with error " << cudaGetErrorString(result) << std::endl;
        return false;
      }
    }

    result = cudaEventRecord(events[0]);

    if (result != cudaSuccess) {
      std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl;
      return false;
    }

    for (int iter = 0; iter < kIterations; ++iter) {

      status = execute_device_kernel();

      if (status != cutlass::Status::kSuccess) {
        std::cerr << "Device execution failed." << std::endl;
        return false;
      }
    }

    result = cudaEventRecord(events[1]);

    if (result != cudaSuccess) {
      std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl;
      return false;
    }

    result = cudaDeviceSynchronize();

    if (result != cudaSuccess) {
      std::cerr << "cudaDeviceSynchronize() failed with error " << cudaGetErrorString(result) << std::endl;
      return false;
    }

    float elapsed_ms = 0;
    result = cudaEventElapsedTime(&elapsed_ms, events[0], events[1]);

    if (result != cudaSuccess) {
      std::cerr << "cudaEventElapsedTime() failed with error " << cudaGetErrorString(result) << std::endl;
      return false;
    }

    float elapsed_ms_per_iter = elapsed_ms / float(kIterations);

    for (cudaEvent_t &evt : events) {
      result = cudaEventDestroy(evt);
      if (result != cudaSuccess) {
        std::cerr << "cudaEventDestroy() failed with error " << cudaGetErrorString(result) << std::endl;
        return false;
      }
    }

    int64_t flops = int64_t(options.problem_size0.m()) * options.problem_size0.n() * options.problem_size0.k() * 2
                  + int64_t(options.problem_size1.m()) * options.problem_size1.n() * options.problem_size1.k() * 2;

    double gflops_per_second = double(flops) * kIterations / double(elapsed_ms / 1000.0f) / double(1.0e9);

    std::cout << "         1st GEMM: "
              << options.problem_size0.m() << "-by-" << options.problem_size0.n() << "-by-" << options.problem_size0.k() << "\n"
              << "         2nd GEMM: "
              << options.problem_size1.m() << "-by-" << options.problem_size1.n() << "-by-" << options.problem_size1.k()
              << std::endl;

    std::cout << "         Runtime / iteration: " << elapsed_ms_per_iter << " ms\n" << std::endl;
    std::cout << "         Throughput: " << gflops_per_second << " GFLOP/s" << std::endl;

    return true;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

int main(int argc, const char **argv) {

  // Define final layout
  using LayoutOutput = cutlass::layout::ColumnMajor;

  // Options parsing
  Options<LayoutOutput> options;
  options.parse(argc, argv);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  if (!options.supported()) {
    return 0;
  }

  // Run
  Testbed<LayoutOutput> testbed(options);

  Disposition disposition = testbed.run();

  std::cout << std::endl;

  switch (disposition) {
    case Disposition::kPassed:
      std::cout << "Passed" << std::endl;
      break;
    case Disposition::kIncorrect:
      std::cout << "Incorrect" << std::endl;
      break;
    case Disposition::kNotVerified:
      std::cout << "Not verified" << std::endl;
      break;
  }

  return (disposition == Disposition::kPassed ? 0 : -1);
}

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -0,0 +1,450 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief GEMM kernel supporting the epilogue visitor model
           for customized layernorm partial-reduction epilogue fusion.

    This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once
    its usage has been stabilized. For now, it is included in this example to demonstrate
    some basic output fusion options.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"

#include "cutlass/trace.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Mma_,                  ///! Threadblock-scoped matrix multiply-accumulate
  typename Epilogue_,             ///! Epilogue
  typename ThreadblockSwizzle_    ///! Threadblock swizzling function
>
struct GemmWithEpilogueVisitor {
public:

  using Mma = Mma_;
  using Epilogue = Epilogue_;
  using EpilogueVisitor = typename Epilogue::Visitor;
  using ThreadblockSwizzle = ThreadblockSwizzle_;

  using ElementA = typename Mma::IteratorA::Element;
  using LayoutA = typename Mma::IteratorA::Layout;
  using TensorRefA = TensorRef<ElementA, LayoutA>;

  using ElementB = typename Mma::IteratorB::Element;
  using LayoutB = typename Mma::IteratorB::Layout;
  using TensorRefB = TensorRef<ElementB, LayoutB>;

  using ElementC = typename EpilogueVisitor::ElementOutput;
  using LayoutC = typename Epilogue::Layout;

  static ComplexTransform const kTransformA = Mma::kTransformA;
  static ComplexTransform const kTransformB = Mma::kTransformB;
  using Operator = typename Mma::Operator;

  using OperatorClass = typename Mma::Operator::OperatorClass;
  using ThreadblockShape = typename Mma::Shape;
  using WarpShape = typename Mma::Operator::Shape;
  using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
  using ArchTag = typename Mma::ArchTag;

  static int const kStages = Mma::kStages;
  static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
  static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
  static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;

  /// Warp count (concept: GemmShape)
  using WarpCount = typename Mma::WarpCount;
  static int const kThreadCount = 32 * WarpCount::kCount;

  /// Split-K preserves splits that are 128b aligned
  static int const kSplitKAlignment = const_max(
    128 / sizeof_bits<ElementA>::value,
    128 / sizeof_bits<ElementB>::value
  );
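  // For example, with cutlass::half_t operands (16 bits each) this evaluates to
  // const_max(128 / 16, 128 / 16) = 8 elements, so every split-K slice begins on a
  // 128b-aligned boundary.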

  //
  // Structures
  //

  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    GemmUniversalMode mode;
    GemmCoord problem_size;

    TensorRefA ref_A;
    TensorRefB ref_B;

    typename EpilogueVisitor::Arguments epilogue_visitor;

    //
    // Methods
    //

    Arguments():
      mode(GemmUniversalMode::kGemm)
    { }

    /// Constructs an arguments structure
    Arguments(
      GemmUniversalMode mode_,
      GemmCoord problem_size_,
      TensorRefA ref_A_,
      TensorRefB ref_B_,
      typename EpilogueVisitor::Arguments epilogue_visitor_
    ):
      mode(mode_),
      problem_size(problem_size_),
      ref_A(ref_A_),
      ref_B(ref_B_),
      epilogue_visitor(epilogue_visitor_)
    {

    }
  };

  //
  // Structure for precomputing values in host memory and passing to kernels
  //

  /// Parameters structure
  struct Params {

    cutlass::gemm::GemmCoord problem_size;
    cutlass::gemm::GemmCoord grid_tiled_shape;
    int swizzle_log_tile;

    typename Mma::IteratorA::Params params_A;
    typename Mma::IteratorB::Params params_B;

    GemmUniversalMode mode;
    int gemm_k_size;

    void * ptr_A;
    void * ptr_B;

    typename EpilogueVisitor::Params epilogue_visitor;

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params():
      swizzle_log_tile(0),
      params_A(0),
      params_B(0),
      gemm_k_size(0),
      mode(cutlass::gemm::GemmUniversalMode::kGemm),
      ptr_A(nullptr),
      ptr_B(nullptr)
    { }

    Params(
      Arguments const &args
    ):
      problem_size(args.problem_size),
      swizzle_log_tile(0),
      params_A(args.ref_A.layout()),
      params_B(args.ref_B.layout()),
      mode(args.mode),
      gemm_k_size(args.problem_size.k()),
      ptr_A(args.ref_A.data()),
      ptr_B(args.ref_B.data()),
      epilogue_visitor(args.epilogue_visitor)
    {

      ThreadblockSwizzle threadblock_swizzle;

      grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
        args.problem_size,
        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, 1);

      if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {

        int const kAlignK = const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);

        gemm_k_size = round_up(args.problem_size.k(), kAlignK);

        if (gemm_k_size) {
          grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
        }
      }

      swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
    }
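    // As a worked example (assuming half_t operands, so kAlignK = 8): for
    // problem_size.k() = 4096, gemm_k_size = round_up(4096, 8) = 4096 and
    // grid_tiled_shape.k() = ceil_div(4096, 4096) = 1, i.e. no split along K.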
  };

  /// Shared memory storage structure
  union SharedStorage {

    typename Mma::SharedStorage main_loop;

    struct {
      typename Epilogue::SharedStorage epilogue;
      typename EpilogueVisitor::SharedStorage visitor;
    } epilogue;
  };
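  // Note: the union lets the mainloop and the epilogue reuse the same shared memory
  // allocation, since the two phases never need their storage at the same time.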

public:

  //
  // Methods
  //

  CUTLASS_DEVICE
  GemmWithEpilogueVisitor() { }

  /// Determines whether the kernel satisfies alignment
  static Status can_implement(
    cutlass::gemm::GemmCoord const & problem_size) {

    CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()");

    static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
    static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
    static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

    bool isAMisaligned = false;
    bool isBMisaligned = false;
    bool isCMisaligned = false;

    if (platform::is_same<LayoutA, layout::RowMajor>::value) {
      isAMisaligned = problem_size.k() % kAlignmentA;
    } else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
      isAMisaligned = problem_size.m() % kAlignmentA;
    } else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
            || platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
      isAMisaligned = problem_size.k() % kAlignmentA;
    }

    if (platform::is_same<LayoutB, layout::RowMajor>::value) {
      isBMisaligned = problem_size.n() % kAlignmentB;
    } else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
      isBMisaligned = problem_size.k() % kAlignmentB;
    } else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
            || platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
      isBMisaligned = problem_size.k() % kAlignmentB;
    }

    if (platform::is_same<LayoutC, layout::RowMajor>::value) {
      isCMisaligned = problem_size.n() % kAlignmentC;
    } else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
      isCMisaligned = problem_size.m() % kAlignmentC;
    } else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
            || platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
      isCMisaligned = problem_size.n() % kAlignmentC;
    }

    if (isAMisaligned) {
      CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for A operand");
      return Status::kErrorMisalignedOperand;
    }

    if (isBMisaligned) {
      CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for B operand");
      return Status::kErrorMisalignedOperand;
    }

    if (isCMisaligned) {
      CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for C operand");
      return Status::kErrorMisalignedOperand;
    }

    CUTLASS_TRACE_HOST("  returning kSuccess");

    return Status::kSuccess;
  }

  static Status can_implement(Arguments const &args) {
    return can_implement(args.problem_size);
  }

  static size_t get_extra_workspace_size(Arguments const &args,
                                         cutlass::gemm::GemmCoord const &grid_tiled_shape) {

    return 0;
  }

  /// Executes one GEMM
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {

    // Compute threadblock location
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    // Early exit if CTA is out of range
    if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
        params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {

      return;
    }

    int offset_k = 0;
    int problem_size_k = params.problem_size.k();

    ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
    ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);

    // Compute initial location in logical coordinates
    cutlass::MatrixCoord tb_offset_A{
      threadblock_tile_offset.m() * Mma::Shape::kM,
      offset_k,
    };

    cutlass::MatrixCoord tb_offset_B{
      offset_k,
      threadblock_tile_offset.n() * Mma::Shape::kN
    };

    // Compute position within threadblock
    int thread_idx = threadIdx.x;

    // Construct iterators to A and B operands
    typename Mma::IteratorA iterator_A(
      params.params_A,
      ptr_A,
      {params.problem_size.m(), problem_size_k},
      thread_idx,
      tb_offset_A);

    typename Mma::IteratorB iterator_B(
      params.params_B,
      ptr_B,
      {problem_size_k, params.problem_size.n()},
      thread_idx,
      tb_offset_B);

    // Broadcast the warp_id computed by lane 0 to ensure dependent code
    // is compiled as warp-uniform.
    int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);

    int lane_idx = threadIdx.x % 32;

    //
    // Main loop
    //

    // Construct thread-scoped matrix multiply
    Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

    typename Mma::FragmentC accumulators;

    accumulators.clear();

    // Compute the number of threadblock-scoped multiply-add iterations along K
    int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;

    // Compute threadblock-scoped matrix multiply-add
    mma(
      gemm_k_iterations,
      accumulators,
      iterator_A,
      iterator_B,
      accumulators);

    //
    // Masked tile iterators constructed from members
    //

    threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    // assume identity swizzle
    MatrixCoord threadblock_offset(
      threadblock_tile_offset.m() * Mma::Shape::kM,
      threadblock_tile_offset.n() * Mma::Shape::kN
    );

    int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

    //
    // Construct the epilogue visitor
    //

    EpilogueVisitor epilogue_visitor(
      params.epilogue_visitor,
      shared_storage.epilogue.visitor,
      params.problem_size.mn(),
      thread_idx,
      warp_idx,
      lane_idx,
      threadblock_offset);

    if (params.mode == GemmUniversalMode::kGemm) {
      // Indicate which position in a serial reduction the output operator is currently updating
      epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
    }
    else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) {
      epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
    }

    // Construct the epilogue
    Epilogue epilogue(
      shared_storage.epilogue.epilogue,
      thread_idx,
      warp_idx,
      lane_idx);

    // Execute the epilogue operator to update the destination tensor.
    epilogue(epilogue_visitor, accumulators);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
1066
examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h
Normal file
1066
examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h
Normal file
File diff suppressed because it is too large
36
examples/38_syr2k_grouped/CMakeLists.txt
Normal file
36
examples/38_syr2k_grouped/CMakeLists.txt
Normal file
@ -0,0 +1,36 @@

# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.



cutlass_example_add_executable(
  38_syr2k_grouped
  syr2k_grouped.cu
)
1461
examples/38_syr2k_grouped/syr2k_grouped.cu
Normal file
1461
examples/38_syr2k_grouped/syr2k_grouped.cu
Normal file
File diff suppressed because it is too large
36
examples/39_gemm_permute/CMakeLists.txt
Normal file
36
examples/39_gemm_permute/CMakeLists.txt
Normal file
@ -0,0 +1,36 @@

# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.



cutlass_example_add_executable(
  39_gemm_permute
  gemm_permute.cu
)
1126
examples/39_gemm_permute/gemm_permute.cu
Normal file
1126
examples/39_gemm_permute/gemm_permute.cu
Normal file
File diff suppressed because it is too large
162
examples/40_cutlass_py/README.md
Normal file
162
examples/40_cutlass_py/README.md
Normal file
@ -0,0 +1,162 @@
# CUTLASS Python Interface Example

## Using Docker
You can run PyCUTLASS in the NGC PyTorch container.
```shell
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:22.08-py3
```
PyCUTLASS additionally depends on the Boost C++ library, which can be installed with
```bash
apt-get update
apt-get -y install libboost-all-dev
```


## Install the Python Interface
The source code for the Python interface is located at `tools/library/scripts/pycutlass`. It requires two environment variables:
* `CUTLASS_PATH`: the root directory of CUTLASS
* `CUDA_INSTALL_PATH`: the directory where the CUDA Toolkit is installed

After setting these two environment variables, PyCUTLASS can be installed with
```shell
cd $CUTLASS_PATH/tools/library/scripts/pycutlass && bash build.sh
```
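A quick way to check the installation afterwards (assuming the package is importable as `pycutlass`, which is how this example refers to it):
```shell
python -c "import pycutlass"
```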
***

## Troubleshooting

### Issue 1: permission denied
Building PyCUTLASS installs dependencies into your Python environment, so conda can be an option if you don't have write permission for the system Python.
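For instance, a minimal sketch of an isolated build (the environment name and Python version here are illustrative, not required by PyCUTLASS):
```shell
# Create and activate a dedicated environment, then build PyCUTLASS inside it
conda create -n pycutlass python=3.8 -y
conda activate pycutlass
cd $CUTLASS_PATH/tools/library/scripts/pycutlass && bash build.sh
```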
|
||||
|
||||
### Issue 2: rmm: module not found
|
||||
PyCUTLASS manages the device memory with [RMM](https://github.com/rapidsai/rmm). Our `build.sh` automatically pull the [rmm branch-22.08](https://github.com/rapidsai/rmm/tree/branch-22.08) from github and build it from source. The rmm is allocated at `$CUTLASS_PATH/tools/library/scripts/pycutlass/rmm`. It requires `cmake > 3.20.1`. If the build fails, it can be manually fixed with the following steps:
|
||||
```shell
|
||||
cd $CUTLASS_PATH/tools/library/scripts/pycutlass/rmm && ./build.sh librmm rmm
|
||||
|
||||
cd $CUTLASS_PATH/tools/library/scripts/pycutlass/rmm/python
|
||||
python setup.py build_ext --inplace
|
||||
python setup.py install
|
||||
```
|
||||
To test whether rmm is successfully installed, try `import rmm`. For other issues related to rmm, please check https://github.com/rapidsai/rmm/issues.
|
||||
|
||||
***
|
||||
For all the tests, add `--print_cuda` to print the underlying CUDA kernel. Use `-h` or `--help` to display the help message.
|
||||
## GEMM Examples
|
||||
The GEMM examples use numpy to create input tensors and verify the results.
|
||||
### GEMM F64 Example
|
||||
Example 1: SM80_Device_Gemm_f64t_f64n_f64n_tensor_op_f64_32x32x16_16x16x16
|
||||
```python
|
||||
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
|
||||
```
|
||||
Example 2: SM80_Device_Gemm_f64n_f64t_f64n_tensor_op_f64_64x64x16_32x32x16, split_k(2)_serial
|
||||
```python
|
||||
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2
|
||||
```
|
||||
|
||||
### GEMM F32 Example
|
||||
Example 1: SM80_Device_Gemm_f32n_f32t_f32n_tensor_op_bf16_f32_128x128x32_64x64x32
|
||||
```python
|
||||
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
|
||||
```
|
||||
Example 2: SM80_Device_Gemm_f32t_f32t_f32n_tensor_op_f32_128x128x32_64x64x32, split_k(2)_parallel
|
||||
```python
|
||||
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2
|
||||
```
|
||||
Example 3: SM80_Device_Gemm_f32t_f32t_f32n_tensor_op_fast_accurate_f32_64x64x32_32x32x32, split_k(4)_serial
|
||||
```python
|
||||
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_f32 -op TensorOp -b 64 64 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 4
|
||||
```
|
||||
|
||||
### GEMM F16 Example
|
||||
Example 1: SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32_128x128x32_64x64x32
|
||||
```python
|
||||
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
|
||||
```
|
||||
Example 2: SM80_Device_Gemm_f16t_f16t_f16n_tensor_op_f32_128x128x64_64x64x64, split_k(2)_serial
|
||||
```python
|
||||
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2
|
||||
```
|
||||
Example 3: SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32_256x128x64_64x64x64, split_k(3)_serial
|
||||
```python
|
||||
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 256 128 64 -s 3 -w 4 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 3
|
||||
```
|
||||
|
||||
### GEMM BF16 Example

Example 1: Device_Gemm_bf16t_bf16t_f32n_tensor_op_f32_64x128x64_32x64x64, split_k(5)_parallel

```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 5
```
### GEMM Int8 Example

Example 1: SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32_128x128x128_64x64x128

```python
python gemm.py -i 16 8 32 -ta int8 -tb int8 -tc int8 -tacc int32 -m multiply_add -op TensorOp -b 128 128 128 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 16 -lb ColumnMajor -ab 16 -lc RowMajor -ac 16 -te float32 -ep FastLinearCombinationClamp -sw IdentitySwizzle2 -p 512 512 512 -alpha 1.0 -beta 0.0 -gm Gemm -k 1
```
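In the Tensor Core test names above, the two trailing shape triples are the threadblock tile and the warp tile (M x N x K). The warp tile follows from `-b` and `-w`: each threadblock dimension is divided by the corresponding warp count. A quick sanity check (the helper below is illustrative, not part of the example scripts):

```python
# Derive the warp tile from the threadblock shape (-b) and warp count (-w).
def warp_tile(threadblock_shape, warp_count):
    return [t // w for t, w in zip(threadblock_shape, warp_count)]

# The Int8 example: -b 128 128 128 -w 2 2 1 -> 64x64x128, matching its name.
print(warp_tile([128, 128, 128], [2, 2, 1]))  # [64, 64, 128]
```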
***

## GEMM Grouped Examples

The grouped GEMM examples use NumPy to create the input tensors and verify the results. The problem sizes are read from a CSV file with one `M,N,K` triple per row, as sketched below.
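For reference, the three problem sizes used by the examples below live in `grouped_gemm_problem_size.csv` (its contents appear further down in this diff); a file in the same format can be generated with a few lines of Python:

```python
import csv

# Write a grouped-GEMM problem-size file: one M,N,K triple per row.
problem_sizes = [(128, 128, 128), (128, 128, 256), (512, 128, 384)]
with open("grouped_gemm_problem_size.csv", "w", newline="") as f:
    writer = csv.writer(f)
    for m, n, k in problem_sizes:
        writer.writerow([m, n, k])
```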
Example 1: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule

```python
python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 0.0 -pm Device
```

Example 2: SM80_Device_GemmGrouped_f64n_f64n_f64t_tensor_op_f64_64x64x16_32x32x16, host schedule

```python
python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle2 -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 1.0 -pm Host
```

Example 3: SM80_Device_GemmGrouped_f32n_f32n_f32n_simt_f32_128x64x8_64x32x1, device schedule

```python
python gemm_grouped.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 64 8 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
```

Example 4: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule

```python
python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle8 -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
```
***

## Conv2d Example

The Conv2d examples use PyTorch to create the input tensors and verify the results. PyTorch can be installed by following the instructions on the [official website](https://pytorch.org).
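The `-nhwc` (input), `-krsc` (filter), `-pad`, `-stride`, and `-dilation` options determine the output extent via the standard cross-correlation size formula. A small sketch for sanity-checking a configuration (the helper is ours, and it assumes symmetric padding, i.e. `pad_h = pad[0]` and `pad_w = pad[2]` applied on both sides):

```python
# Output height/width of a 2d convolution (cross-correlation).
def conv2d_output_hw(h, w, r, s, pad_h, pad_w, stride_h, stride_w, dil_h=1, dil_w=1):
    p = (h + 2 * pad_h - dil_h * (r - 1) - 1) // stride_h + 1
    q = (w + 2 * pad_w - dil_w * (s - 1) - 1) // stride_w + 1
    return p, q

# Fprop Example 1 below: -nhwc 1 13 17 8 -krsc 24 3 3 8 -pad 0 0 0 0 -stride 2 2
print(conv2d_output_hw(13, 17, 3, 3, 0, 0, 2, 2))  # (6, 8)
```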
### Conv2d F32 Fprop

Example 1: SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32

```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 13 17 8 -krsc 24 3 3 8 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```

Example 2: SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align2

```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 1.0 -beta 1.0
```

Example 3: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32

```python
python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 4 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -co fprop -st Strided -ia analytic -sm Parallel -k 3 -nhwc 1 71 80 32 -krsc 64 5 5 32 -pad 2 2 2 2 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 1.0
```
### Conv2d F32 Wgrad

Example 1: Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align1

```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 1 -lb TensorNHWC -ab 1 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 8 8 1 -krsc 1 3 3 1 -pad 1 1 1 1 -stride 1 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```

Example 2: Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32

```python
python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 2 4 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```
### Conv2d F32 Dgrad

Example 1: Device_Conv2d_Dgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32

```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```
### Conv2d F16 Fprop

Example 1: SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32

```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```

Example 2: SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_2

```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```

Example 3: SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_8

```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia fixed_channels -sm Serial -k 1 -nhwc 1 8 8 8 -krsc 16 3 3 8 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```

Example 4: SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_128x128_32x3_64x64x32_align4

```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 56 56 12 -krsc 8 1 1 12 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```
277
examples/40_cutlass_py/conv2d.py
Normal file
@ -0,0 +1,277 @@
################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
import pycutlass
from pycutlass import *
from pycutlass.conv2d_operation import *
from pycutlass.utils import reference_model

import argparse
import sys  # needed for sys.exit below
import torch  # tensors are created and verified with PyTorch
# parse the arguments
parser = argparse.ArgumentParser(description="Launch CUTLASS convolution 2d kernels from python")

# Operation description
# math instruction description
parser.add_argument("-i", "--instruction_shape",
                    default=[1, 1, 1], nargs=3, type=int,
                    help="This option describes the size of MMA op")
parser.add_argument("-ta", "--element_a", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of elements in input tensor A')
parser.add_argument("-tb", "--element_b", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of elements in input tensor B')
parser.add_argument("-tc", "--element_c", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of elements in input tensor C and output tensor D')
parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of accumulator')
parser.add_argument('-m', "--math", default="multiply_add",
                    type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"],
                    help="math instruction")
parser.add_argument('-op', "--opcode", default="Simt", type=str,  # default must match the 'Simt' choice
                    choices=["Simt", 'TensorOp'],
                    help='This option describes whether you want to use tensor \
                    cores (TensorOp) or regular SIMT cores (Simt) on GPU SM')
# tile description
parser.add_argument("-b", "--threadblock_shape",
                    default=[128, 128, 8], nargs=3, type=int,
                    help="This option describes the tile size a thread block computes")
parser.add_argument("-s", "--stages", default=4,
                    type=int, help="Number of pipeline stages you want to use")
parser.add_argument("-w", "--warp_count", default=[4, 2, 1], nargs=3, type=int,
                    help="This option describes the number of warps along M, N, and K of the threadblock")
parser.add_argument("-cc", "--compute_capability", default=80,
                    type=int, help="This option describes CUDA SM architecture number")
# A
parser.add_argument('-la', "--layout_a", default="TensorNHWC", type=str,
                    choices=["TensorNHWC", "TensorNC32HW32"],
                    help="Memory layout of input tensor A")
parser.add_argument('-aa', '--alignment_a', default=1,
                    type=int, help="Memory alignment of input tensor A")
# B
parser.add_argument('-lb', "--layout_b", default="TensorNHWC", type=str,
                    choices=["TensorNHWC", "TensorC32RSK32"],
                    help="Memory layout of input tensor B")
parser.add_argument('-ab', '--alignment_b', default=1,
                    type=int, help="Memory alignment of input tensor B")
# C
parser.add_argument('-lc', "--layout_c", default="TensorNHWC", type=str,
                    choices=["TensorNHWC", "TensorNC32HW32"],
                    help="Memory layout of input tensor C and output tensor D")
parser.add_argument('-ac', '--alignment_c', default=1,
                    type=int, help="Memory alignment of input tensor C and output tensor D")
# epilogue
parser.add_argument("-te", "--element_epilogue", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16'],
                    help='Data type of computation in the epilogue')
parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
                    type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'],
                    help="This option describes the epilogue part of the kernel")
# swizzling
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str,
                    choices=["IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8",
                             "HorizontalSwizzle", "StridedDgradIdentitySwizzle1", "StridedDgradIdentitySwizzle4",
                             "StridedDgradHorizontalSwizzle"],
                    help="This option describes how thread blocks are scheduled on GPU")
# conv related
parser.add_argument("-co", "--conv_kind", default="fprop", type=str, choices=['fprop', 'dgrad', 'wgrad'],
                    help="The type of convolution: forward propagation (fprop), \
                    gradient of activation (dgrad), gradient of weight (wgrad)")
parser.add_argument("-st", "--stride_support", default="Strided", type=str, choices=["Strided", "Unity"],
                    help="Whether the kernel supports arbitrary ('Strided') or only unit ('Unity') convolution stride")
parser.add_argument("-ia", "--iterator_algorithm", default="analytic", type=str,
                    choices=["analytic", "optimized", "fixed_channels", "few_channels"],
                    help="This option describes the iterator algorithm")

# arguments
parser.add_argument("-sm", "--split_k_mode", default="Serial", type=str, choices=["Serial", "Parallel"],
                    help="Split K Mode. Serial is used for non-splitK or serial-splitK. \
                    Parallel is used for parallel splitK.")
parser.add_argument('-k', '--split_k_slices', default=1,
                    type=int, help="Number of split-k partitions. (default 1)")
parser.add_argument("-nhwc", "--nhwc", nargs=4, type=int, help="input size (NHWC)")
parser.add_argument("-krsc", "--krsc", nargs=4, type=int, help="filter size (KRSC)")
parser.add_argument("-pad", "--pad", nargs=4, type=int, help="padding (pad_h, _, pad_w, _)")
parser.add_argument("-stride", "--stride", nargs=2, type=int, help="stride (stride_h, stride_w)")
parser.add_argument("-dilation", "--dilation", nargs=2, type=int, help="dilation (dilation_h, dilation_w)")
parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha")
parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta")

parser.add_argument('--print_cuda', action="store_true",
                    help="print the underlying CUDA kernel")

try:
    args = parser.parse_args()
except:
    sys.exit(0)

pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
element_acc = getattr(cutlass, args.element_acc)
math_operation = getattr(MathOperation, args.math)
opclass = getattr(cutlass.OpClass, args.opcode)

math_inst = MathInstruction(
    args.instruction_shape, element_a, element_b,
    element_acc, opclass, math_operation
)

tile_description = TileDescription(
    args.threadblock_shape, args.stages, args.warp_count,
    math_inst, args.compute_capability, args.compute_capability
)

layout_a = getattr(cutlass, args.layout_a)
layout_b = getattr(cutlass, args.layout_b)
layout_c = getattr(cutlass, args.layout_c)

A = TensorDescription(
    element_a, layout_a, args.alignment_a
)

B = TensorDescription(
    element_b, layout_b, args.alignment_b
)

C = TensorDescription(
    element_c, layout_c, args.alignment_c
)

element_epilogue = getattr(cutlass, args.element_epilogue)
epilogue_functor = getattr(EpilogueFunctor, args.epilogue_functor)
iterator_algorithm = getattr(cutlass.conv.IteratorAlgorithm, args.iterator_algorithm)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
stride_support = getattr(StrideSupport, args.stride_support)
conv_kind = getattr(cutlass.conv.Operator, args.conv_kind)

operation = Conv2dOperation(
    conv_kind=conv_kind, iterator_algorithm=iterator_algorithm,
    arch=args.compute_capability, tile_description=tile_description,
    A=A, B=B, C=C, element_epilogue=element_epilogue, stride_support=stride_support,
    epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)

if args.print_cuda:
    print(operation.rt_module.emit())

operations = [operation,]
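# With parallel split-K, each slice writes partial results to a workspace and a
# separate reduction kernel combines them, so that kernel is compiled alongside
# the convolution below.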
if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
    reduction_operation = ReductionOperation(
        shape=cutlass.MatrixCoord(4, 32 * C.alignment),
        C=C, element_accumulator=element_acc,
        element_compute=element_epilogue,
        count=C.alignment
    )
    operations.append(reduction_operation)
pycutlass.compiler.add_module(operations)

problem_size = cutlass.conv.Conv2dProblemSize(
    cutlass.Tensor4DCoord(args.nhwc[0], args.nhwc[1], args.nhwc[2], args.nhwc[3]),
    cutlass.Tensor4DCoord(args.krsc[0], args.krsc[1], args.krsc[2], args.krsc[3]),
    cutlass.Tensor4DCoord(args.pad[0], args.pad[1], args.pad[2], args.pad[3]),
    cutlass.MatrixCoord(args.stride[0], args.stride[1]),
    cutlass.MatrixCoord(args.dilation[0], args.dilation[1]),
    cutlass.conv.Mode.cross_correlation,
    args.split_k_slices, 1
)

# User-provided inputs
tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size(
    conv_kind, problem_size
)
tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size(
    conv_kind, problem_size
)
tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size(
    conv_kind, problem_size
)
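# Inputs are initialized to whole numbers (ceil of a uniform distribution) so the
# device result can be compared exactly against the reference model at the end.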
if args.element_a != "int8":
    tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-8.5, 7.5))
else:
    # uniform_ is only defined for floating-point tensors; draw small random int8 values instead
    tensor_A = torch.randint(-2, 2, size=(tensor_A_size,), dtype=torch.int8, device="cuda")

if args.element_b != "int8":
    tensor_B = torch.ceil(torch.empty(size=(tensor_B_size,), dtype=getattr(torch, args.element_b), device="cuda").uniform_(-8.5, 7.5))
else:
    tensor_B = torch.randint(-2, 2, size=(tensor_B_size,), dtype=torch.int8, device="cuda")

if args.element_c != "int8":
    tensor_C = torch.ceil(torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-8.5, 7.5))
else:
    tensor_C = torch.randint(-2, 2, size=(tensor_C_size,), dtype=torch.int8, device="cuda")

tensor_D = torch.ones_like(tensor_C)

arguments = Conv2dArguments(
    operation=operation, problem_size=problem_size, A=tensor_A,
    B=tensor_B, C=tensor_C, D=tensor_D,
    output_op=LinearCombinationFunctorArguments(args.alpha, args.beta),
    split_k_mode=getattr(cutlass.conv.SplitKMode, args.split_k_mode),
    split_k_slices=problem_size.split_k_slices
)

if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
    implicit_gemm_size = cutlass.conv.implicit_gemm_problem_size(conv_kind, arguments.problem_size)
    reduction_arguments = ReductionArguments(
        reduction_operation,
        problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()],
        partitions=problem_size.split_k_slices,
        workspace=arguments.ptr_D,
        destination=tensor_D,
        source=tensor_C,
        output_op=LinearCombinationFunctorArguments(args.alpha, args.beta)
    )

operation.run(arguments)

if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
    reduction_operation.run(reduction_arguments)
    reduction_arguments.sync()
else:
    arguments.sync()

reference_model = Conv2dReferenceModule(A, B, C, conv_kind)

tensor_D_ref = reference_model.run(tensor_A, tensor_B, tensor_C, arguments.problem_size, args.alpha, args.beta)

assert torch.equal(tensor_D, tensor_D_ref)

print("Passed.")
266
examples/40_cutlass_py/gemm.py
Normal file
@ -0,0 +1,266 @@
################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
import numpy as np
import pycutlass
from pycutlass import *
import cutlass
from bfloat16 import bfloat16

import argparse
import sys  # needed for sys.exit below
# parse the arguments
parser = argparse.ArgumentParser(
    description="Launch CUTLASS GEMM kernels from python: 'D = alpha * A * B + beta * C'")

# Operation description
# math instruction description
parser.add_argument("-i", "--instruction_shape",
                    default=[1, 1, 1], nargs=3, type=int,
                    help="This option describes the size of MMA op")
parser.add_argument("-ta", "--element_a", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of elements in input tensor A')
parser.add_argument("-tb", "--element_b", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of elements in input tensor B')
parser.add_argument("-tc", "--element_c", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of elements in input tensor C and output tensor D')
parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of accumulator')
parser.add_argument('-m', "--math", default="multiply_add",
                    type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"],
                    help="math instruction")
parser.add_argument('-op', "--opcode", default="Simt", type=str,  # default must match the 'Simt' choice
                    choices=["Simt", 'TensorOp'],
                    help="This option describes whether you want to use tensor \
                    cores (TensorOp) or regular SIMT cores (Simt) on GPU SM")
# tile description
parser.add_argument("-b", "--threadblock_shape",
                    default=[128, 128, 8], nargs=3, type=int,
                    help="This option describes the tile size a thread block computes")
parser.add_argument("-s", "--stages", default=4,
                    type=int, help="Number of pipeline stages you want to use")
parser.add_argument("-w", "--warp_count", default=[4, 2, 1], nargs=3, type=int,
                    help="This option describes the number of warps along M, N, and K of the threadblock")
parser.add_argument("-cc", "--compute_capability", default=80,
                    type=int, help="This option describes CUDA SM architecture number")
# A
parser.add_argument('-la', "--layout_a", default="RowMajor", type=str,
                    choices=["RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
                    help="Memory layout of input tensor A")
parser.add_argument('-aa', '--alignment_a', default=1,
                    type=int, help="Memory alignment of input tensor A")
# B
parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str,
                    choices=["RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
                    help="Memory layout of input tensor B")
parser.add_argument('-ab', '--alignment_b', default=1,
                    type=int, help="Memory alignment of input tensor B")
# C
parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str,
                    choices=["RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
                    help="Memory layout of input tensor C and output tensor D")
parser.add_argument('-ac', '--alignment_c', default=1,
                    type=int, help="Memory alignment of input tensor C and output tensor D")
# epilogue
parser.add_argument("-te", "--element_epilogue", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16'], help='Epilogue datatype')
parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
                    type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'],
                    help="This option describes the epilogue part of the kernel")
# swizzling
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str,
                    choices=["IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle"],
                    help="This option describes how thread blocks are scheduled on GPU")

# Argument
parser.add_argument("-p", "--problem_size",
                    default=[128, 128, 128], nargs=3, type=int,
                    help="GEMM problem size M, N, K")
parser.add_argument("-alpha", "--alpha", default=1.0, type=float,
                    help="Scaling factor of A * B")
parser.add_argument("-beta", "--beta", default=0.0, type=float,
                    help="Scaling factor of C")
parser.add_argument("-gm", "--gemm_mode", default="Gemm", type=str,
                    choices=["Gemm", "GemmSplitKParallel"],
                    help="GEMM mode. Gemm is used for non-splitK or serial-splitK. \
                    GemmSplitKParallel is used for parallel splitK")
parser.add_argument('-k', '--split_k_slices', default=1,
                    type=int, help="Number of split-k partitions. (default 1)")

parser.add_argument('--print_cuda', action="store_true",
                    help="print the underlying CUDA kernel")

# parser.add_argument('-h', '--help', action="store_true",
#                     help="print help information")

try:
    args = parser.parse_args()
except:
    sys.exit(0)

pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
element_acc = getattr(cutlass, args.element_acc)
math_operation = getattr(MathOperation, args.math)
opclass = getattr(cutlass.OpClass, args.opcode)

math_inst = MathInstruction(
    args.instruction_shape, element_a, element_b,
    element_acc, opclass, math_operation
)

tile_description = TileDescription(
    args.threadblock_shape, args.stages, args.warp_count,
    math_inst, args.compute_capability, args.compute_capability
)

layout_a = getattr(cutlass, args.layout_a)
layout_b = getattr(cutlass, args.layout_b)
layout_c = getattr(cutlass, args.layout_c)

A = TensorDescription(
    element_a, layout_a, args.alignment_a
)

B = TensorDescription(
    element_b, layout_b, args.alignment_b
)

C = TensorDescription(
    element_c, layout_c, args.alignment_c
)

element_epilogue = getattr(cutlass, args.element_epilogue)
epilogue_functor = getattr(EpilogueFunctor, args.epilogue_functor)
swizzling_functor = getattr(cutlass, args.swizzling_functor)

operation = GemmOperationUniversal(
    arch=args.compute_capability, tile_description=tile_description,
    A=A, B=B, C=C, element_epilogue=element_epilogue,
    epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)

if args.print_cuda:
    print(operation.rt_module.emit())

operations = [operation, ]
if args.gemm_mode == "GemmSplitKParallel":
    reduction_operation = ReductionOperation(
        shape=cutlass.MatrixCoord(4, 32 * C.alignment),
        C=C, element_accumulator=element_acc,
        element_compute=element_epilogue,
        count=C.alignment
    )
    operations.append(reduction_operation)

pycutlass.compiler.add_module(operations)

# User-provided inputs

problem_size = cutlass.gemm.GemmCoord(
    args.problem_size[0], args.problem_size[1], args.problem_size[2])
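# Whole-number inputs (ceil of a uniform distribution) keep the computation exact,
# so the result can be checked with np.array_equal against the host reference.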
if args.element_a != "int8":
    if args.element_a == "bfloat16":
        tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
                           * problem_size.k(),))).astype(bfloat16)
    else:
        tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
                           * problem_size.k(),))).astype(getattr(np, args.element_a))
else:
    tensor_A = np.random.uniform(low=-2, high=2, size=(problem_size.m()
                                 * problem_size.k(),)).astype(getattr(np, args.element_a))

if args.element_b != "int8":
    if args.element_b == "bfloat16":
        tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k()
                           * problem_size.n(),))).astype(bfloat16)
    else:
        tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k()
                           * problem_size.n(),))).astype(getattr(np, args.element_b))
else:
    tensor_B = np.random.uniform(low=-2, high=2, size=(problem_size.k()
                                 * problem_size.n(),)).astype(getattr(np, args.element_b))

if args.element_c != "int8":
    if args.element_c == "bfloat16":
        tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
                           * problem_size.n(),))).astype(bfloat16)
    else:
        tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
                           * problem_size.n(),))).astype(getattr(np, args.element_c))
else:
    tensor_C = np.random.uniform(low=-2, high=2, size=(problem_size.m()
                                 * problem_size.n(),)).astype(getattr(np, args.element_c))

tensor_D = np.ones_like(tensor_C)

arguments = GemmArguments(
    operation=operation, problem_size=problem_size,
    A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
    output_op=LinearCombinationFunctorArguments(args.alpha, args.beta),
    gemm_mode=getattr(cutlass.gemm.Mode, args.gemm_mode),
    split_k_slices=args.split_k_slices
)

if args.gemm_mode == "GemmSplitKParallel":
    reduction_arguments = ReductionArguments(
        operation=reduction_operation,
        problem_size=[problem_size.m(), problem_size.n()],
        partitions=args.split_k_slices, workspace=arguments.ptr_D,
        destination=tensor_D, source=tensor_C,
        output_op=LinearCombinationFunctorArguments(args.alpha, args.beta)
    )

operation.run(arguments)

if args.gemm_mode == "GemmSplitKParallel":
    reduction_operation.run(reduction_arguments)
    reduction_arguments.sync()
else:
    arguments.sync()

# run the host reference module
reference = ReferenceModule(A, B, C)
tensor_D_ref = reference.run(
    tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta)

assert np.array_equal(tensor_D, tensor_D_ref)

print("Passed.")
248
examples/40_cutlass_py/gemm_grouped.py
Normal file
@ -0,0 +1,248 @@
################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
import pycutlass
from pycutlass import *
import csv

import argparse
import sys  # needed for sys.exit below

# used below but not guaranteed by the star import above; import explicitly
import numpy as np
from bfloat16 import bfloat16
# parse the arguments
parser = argparse.ArgumentParser(
    description="Launch CUTLASS GEMM Grouped kernels from python")

# Operation description
# math instruction description
parser.add_argument("-i", "--instruction_shape",
                    default=[1, 1, 1], nargs=3, type=int,
                    help="This option describes the size of MMA op")
parser.add_argument("-ta", "--element_a", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of elements in input tensor A')
parser.add_argument("-tb", "--element_b", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of elements in input tensor B')
parser.add_argument("-tc", "--element_c", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of elements in input tensor C and output tensor D')
parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
                    help='Data type of accumulator')
parser.add_argument('-m', "--math", default="multiply_add",
                    type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"],
                    help="math instruction")
parser.add_argument('-op', "--opcode", default="Simt", type=str,  # default must match the 'Simt' choice
                    choices=["Simt", 'TensorOp'],
                    help='This option describes whether you want to use tensor \
                    cores (TensorOp) or regular SIMT cores (Simt) on GPU SM')
# tile description
parser.add_argument("-b", "--threadblock_shape",
                    default=[128, 128, 8], nargs=3, type=int,
                    help="This option describes the tile size a thread block computes")
parser.add_argument("-s", "--stages", default=4,
                    type=int, help="Number of pipeline stages you want to use")
parser.add_argument("-w", "--warp_count", default=[4, 2, 1], nargs=3, type=int,
                    help="This option describes the number of warps along M, N, and K of the threadblock")
parser.add_argument("-cc", "--compute_capability", default=80,
                    type=int, help="This option describes CUDA SM architecture number")
# A
parser.add_argument('-la', "--layout_a", default="RowMajor", type=str,
                    choices=["RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
                    help="Memory layout of input tensor A")
parser.add_argument('-aa', '--alignment_a', default=1,
                    type=int, help="Memory alignment of input tensor A")
# B
parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str,
                    choices=["RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
                    help="Memory layout of input tensor B")
parser.add_argument('-ab', '--alignment_b', default=1,
                    type=int, help="Memory alignment of input tensor B")
# C
parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str,
                    choices=["RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
                    help="Memory layout of input tensor C and output tensor D")
parser.add_argument('-ac', '--alignment_c', default=1,
                    type=int, help="Memory alignment of input tensor C and output tensor D")
# epilogue
parser.add_argument("-te", "--element_epilogue", default="float32", type=str,
                    choices=['float64', 'float32', 'float16', 'bfloat16'], help='Epilogue datatype')
parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
                    type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'],
                    help="This option describes the epilogue part of the kernel")
# swizzling
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str,
                    choices=["IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle"],
                    help="This option describes how thread blocks are scheduled on GPU")
# precompute mode
parser.add_argument("-pm", "--precompute_mode",
                    default="Device", type=str, choices=["Host", "Device"],
                    help="Grouped GEMM scheduling on device only (Device) or using host precompute (Host)")
# arguments
parser.add_argument("-p", "--problem_size_dir", type=str,
                    help="path to the CSV file containing the problem sizes")
parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha")
parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta")

parser.add_argument('--print_cuda', action="store_true",
                    help="print the underlying CUDA kernel")

try:
    args = parser.parse_args()
except:
    sys.exit(0)
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)

element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
element_acc = getattr(cutlass, args.element_acc)
math_operation = getattr(MathOperation, args.math)
opclass = getattr(cutlass.OpClass, args.opcode)

math_inst = MathInstruction(
    args.instruction_shape, element_a, element_b,
    element_acc, opclass, math_operation
)

tile_description = TileDescription(
    args.threadblock_shape, args.stages, args.warp_count,
    math_inst, args.compute_capability, args.compute_capability
)

layout_a = getattr(cutlass, args.layout_a)
layout_b = getattr(cutlass, args.layout_b)
layout_c = getattr(cutlass, args.layout_c)

A = TensorDescription(
    element_a, layout_a, args.alignment_a
)

B = TensorDescription(
    element_b, layout_b, args.alignment_b
)

C = TensorDescription(
    element_c, layout_c, args.alignment_c
)

element_epilogue = getattr(cutlass, args.element_epilogue)
epilogue_functor = getattr(EpilogueFunctor, args.epilogue_functor)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
precompute_mode = getattr(SchedulerMode, args.precompute_mode)
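# SchedulerMode.Host precomputes the grouped-GEMM schedule on the host;
# SchedulerMode.Device computes it on the device at launch (see --precompute_mode).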

operation = GemmOperationGrouped(
    arch=args.compute_capability, tile_description=tile_description,
    A=A, B=B, C=C, element_epilogue=element_epilogue,
    epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor,
    precompute_mode=precompute_mode
)

if args.print_cuda:
    print(operation.rt_module.emit())

pycutlass.compiler.add_module([operation, ])

reference_module = ReferenceModule(A, B, C)

# get problems
problem_sizes = []
with open(args.problem_size_dir) as csv_file:
    reader = csv.reader(csv_file)
    for row in reader:
        problem_sizes.append(
            cutlass.gemm.GemmCoord(int(row[0]), int(row[1]), int(row[2]))
        )

problem_count = len(problem_sizes)

tensor_As = []
tensor_Bs = []
tensor_Cs = []
tensor_Ds = []
problem_sizes_coord = []
tensor_D_refs = []
for problem_size in problem_sizes:
    if args.element_a != "int8":
        if args.element_a == "bfloat16":
            tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
                               * problem_size.k(),))).astype(bfloat16)
        else:
            tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
                               * problem_size.k(),))).astype(getattr(np, args.element_a))
    else:
        tensor_A = np.random.uniform(low=-2, high=2, size=(problem_size.m()
                                     * problem_size.k(),)).astype(getattr(np, args.element_a))

    if args.element_b != "int8":
        if args.element_b == "bfloat16":
            tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k()
                               * problem_size.n(),))).astype(bfloat16)
        else:
            tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k()
                               * problem_size.n(),))).astype(getattr(np, args.element_b))
    else:
        tensor_B = np.random.uniform(low=-2, high=2, size=(problem_size.k()
                                     * problem_size.n(),)).astype(getattr(np, args.element_b))

    if args.element_c != "int8":
        if args.element_c == "bfloat16":
            tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
                               * problem_size.n(),))).astype(bfloat16)
        else:
            tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
                               * problem_size.n(),))).astype(getattr(np, args.element_c))
    else:
        tensor_C = np.random.uniform(low=-2, high=2, size=(problem_size.m()
                                     * problem_size.n(),)).astype(getattr(np, args.element_c))
    tensor_D = np.zeros_like(tensor_C)

    tensor_As.append(tensor_A)
    tensor_Bs.append(tensor_B)
    tensor_Cs.append(tensor_C)
    tensor_Ds.append(tensor_D)
    tensor_D_refs.append(reference_module.run(
        tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta))
    problem_sizes_coord.append(problem_size)

arguments = GemmGroupedArguments(
    operation, problem_sizes_coord, tensor_As, tensor_Bs, tensor_Cs, tensor_Ds,
    output_op=LinearCombinationFunctorArguments(args.alpha, args.beta)
)

operation.run(arguments)

arguments.sync()

for tensor_d, tensor_d_ref in zip(tensor_Ds, tensor_D_refs):
    assert np.array_equal(tensor_d, tensor_d_ref)

print("Passed.")
3
examples/40_cutlass_py/grouped_gemm_problem_size.csv
Normal file
@ -0,0 +1,3 @@
128,128,128
128,128,256
512,128,384
@ -1,169 +0,0 @@

# System modules
import numpy as np
import os.path
import sys
import ctypes

# CUDA Python modules
from cuda import cuda
from cuda import nvrtc

# CUTLASS modules
import library
import manifest as cutlass_manifest
import generator
import rt


#
# Construct an SGEMM
#

manifest = cutlass_manifest.Manifest()

generator.GenerateSM50_Simt(manifest, "11.5.0")

#
# Construct a GEMM operation
#

operation = manifest.operations_by_name['cutlass_simt_sgemm_128x128_8x2_nt_align1']

#
# Construct a runtime GEMM operation
#
gemm = rt.Gemm(operation)

#
# Initialize context
#
err, = cuda.cuInit(0)

if err != cuda.CUresult.CUDA_SUCCESS:
    raise RuntimeError("CUDA Error %s" % str(err))

err, device = cuda.cuDeviceGet(0)

if err != cuda.CUresult.CUDA_SUCCESS:
    raise RuntimeError("CUDA Error %s" % str(err))

err, context = cuda.cuCtxCreate(0, device)

if err != cuda.CUresult.CUDA_SUCCESS:
    raise RuntimeError("CUDA Error %s" % str(err))

#
# Construct a module
#

architectures = [80,]
include_paths = [
    '../../include',
    '../../tools/util/include',
]

compilation_options = rt.CompilationOptions(architectures, include_paths)

module = rt.Module('module.cu', [gemm], compilation_options)

#
# Setup a workspace
#

M, N, K = (128, 128, 128)

tensor_A = np.ndarray(M * K, dtype=np.float32)
tensor_B = np.ndarray(N * K, dtype=np.float32)
tensor_C = np.ndarray(M * N, dtype=np.float32)
tensor_D = np.ndarray(M * N, dtype=np.float32)

err, tensor_A_d = cuda.cuMemAlloc(tensor_A.size * tensor_A.itemsize)
if err != cuda.CUresult.CUDA_SUCCESS:
    raise RuntimeError("CUDA Error %s" % str(err))

err, tensor_B_d = cuda.cuMemAlloc(tensor_B.size * tensor_B.itemsize)
if err != cuda.CUresult.CUDA_SUCCESS:
    raise RuntimeError("CUDA Error %s" % str(err))

err, tensor_C_d = cuda.cuMemAlloc(tensor_C.size * tensor_C.itemsize)
if err != cuda.CUresult.CUDA_SUCCESS:
    raise RuntimeError("CUDA Error %s" % str(err))

err, tensor_D_d = cuda.cuMemAlloc(tensor_D.size * tensor_D.itemsize)
if err != cuda.CUresult.CUDA_SUCCESS:
    raise RuntimeError("CUDA Error %s" % str(err))

err, stream = cuda.cuStreamCreate(0)
if err != cuda.CUresult.CUDA_SUCCESS:
    raise RuntimeError("CUDA Error %s" % str(err))

tensors = [
    (tensor_A_d, tensor_A),
    (tensor_B_d, tensor_B),
    (tensor_C_d, tensor_C),
    (tensor_D_d, tensor_D)
]

for tensor_device, tensor_host in tensors:
    bytes = tensor_host.size * tensor_host.itemsize
    print("Tensor has dimensions: %s (%d bytes)" % (str(tensor_host.size), tensor_host.itemsize))
    err, = cuda.cuMemcpyHtoDAsync(tensor_device, tensor_host, bytes, stream)
    print("updating tensor in device memory ", hex(int(tensor_device)))
    if err != cuda.CUresult.CUDA_SUCCESS:
        raise RuntimeError('CUDA Error %s' % str(err))

#
# Initialize a host buffer
#

arguments = rt.GemmArguments()

arguments.problem_size = rt.GemmCoord(M, N, K)

arguments.A = rt.TensorRef(tensor_A_d, M)
arguments.B = rt.TensorRef(tensor_B_d, N)
arguments.C = rt.TensorRef(tensor_C_d, M)
arguments.D = rt.TensorRef(tensor_D_d, M)

host_workspace = bytearray(gemm.get_host_workspace_size(arguments))
device_workspace = None

launch_config = gemm.plan(arguments)

byte_count = gemm.initialize(host_workspace, device_workspace, launch_config, arguments)

#
# Launch the kernel
#

err = gemm.run(host_workspace, device_workspace, launch_config)

if err != cuda.CUresult.CUDA_SUCCESS:
    raise RuntimeError('CUDA Error %s' % str(err))

#
# Verify results
#
err, = cuda.cuStreamSynchronize(stream)

if err != cuda.CUresult.CUDA_SUCCESS:
    raise RuntimeError("CUDA Error %s" % str(err))


#
# Debug reporting of byte array contents
#

def PrintBytearray(host_workspace):
    uint_str = None
    prefix = None
    print("uint32_t host_workspace[] = {")
    for idx, byte in enumerate(host_workspace):
        if not (idx % 4):
            if uint_str is not None:
                print(prefix, uint_str, ",")
            prefix = "/* offset: %d B */ 0x" % idx
            uint_str = ""
        uint_str = "{:02x}".format(byte) + uint_str
    print("};")
36
examples/41_multi_head_attention/CMakeLists.txt
Normal file
@ -0,0 +1,36 @@

# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.



cutlass_example_add_executable(
  41_multi_head_attention
  fused_multihead_attention.cu
)
1145
examples/41_multi_head_attention/fused_multihead_attention.cu
Normal file
File diff suppressed because it is too large
626
examples/41_multi_head_attention/gemm_attention.h
Normal file
@ -0,0 +1,626 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Defines the FusedMultiHeadAttention class.

    The class contains the following:
      1) GEMM0 with epilogue fusion,
      2) GEMM1 with mainloop fusion, and
      3) a lightweight full softmax reduction kernel.
*/

#pragma once

/////////////////////////////////////////////////////////////////////////////////////////////////

#include <cmath>
#include <iostream>
#include <vector>
#include <limits>

#include "cutlass/cutlass.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/gemm/kernel/default_gemm_grouped_softmax_mainloop_fusion.h"
#include "cutlass/reduction/kernel/reduce_softmax_final.h"
#include "gemm_grouped_with_softmax_visitor.h"

namespace cutlass {

template <
  typename ElementQ_,
  typename LayoutQ_,
  typename ElementK_,
  typename LayoutK_,
  typename ElementP_,
  typename LayoutP_,
  typename ElementCompute_,
  typename OperatorClass_,
  typename ArchTag_,
  typename ThreadblockShape0_,
  typename ThreadblockShape1_,
  typename WarpShape0_,
  typename WarpShape1_,
  typename InstructionShape_,
  int kStages0_,
  int kStages1_,
  bool UseMasking_ = false,
  cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode0_ = cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute,
  cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode1_ = cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute,
  int Alignment = 128 / cutlass::sizeof_bits<ElementQ_>::value,
  typename ElementSoftmax_ = ElementP_
>
class FusedMultiHeadAttention {
public:

  using ElementQ = ElementQ_;
  using ElementK = ElementK_;
  using ElementP = ElementP_;
  using ElementV = ElementK;
  using ElementOutput = ElementP;
  using ElementAccumulator = ElementCompute_;

  using LayoutQ = LayoutQ_;
  using LayoutK = LayoutK_;
  using LayoutP = LayoutP_;
  using LayoutV = LayoutK;
  using LayoutO = LayoutP;

  using ElementNorm = cutlass::half_t;
  using ElementSum = cutlass::half_t;
  using ElementSoftmaxCompute = float;
  using LayoutNorm = cutlass::layout::RowMajor;

  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;

  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;

  using ThreadblockShape0 = ThreadblockShape0_;
  using WarpShape0 = WarpShape0_;

  using ThreadblockShape1 = ThreadblockShape1_;
  using WarpShape1 = WarpShape1_;

  static int const Stages0 = kStages0_;
  static int const Stages1 = kStages1_;

  using InstructionShape = InstructionShape_;

  using EpilogueOutputOp0 = cutlass::epilogue::thread::LinearCombination<
    ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
    ElementAccumulator, ElementAccumulator, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;

  using EpilogueOutputOp1 = cutlass::epilogue::thread::LinearCombination<
    ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
    ElementAccumulator, ElementAccumulator, cutlass::epilogue::thread::ScaleType::Nothing>;

  using Operator = typename cutlass::gemm::device::DefaultGemmConfiguration<
    OperatorClass, ArchTag, ElementQ, ElementK, ElementP,
    ElementAccumulator>::Operator;

  static bool const kInternalTranspose = cutlass::platform::is_same<LayoutP, cutlass::layout::ColumnMajor>::value;

  static bool const kUseMasking = UseMasking_;

  static cutlass::gemm::kernel::GroupScheduleMode const kGroupScheduleMode0 = GroupScheduleMode0_;
  static cutlass::gemm::kernel::GroupScheduleMode const kGroupScheduleMode1 = GroupScheduleMode1_;

  using MapArguments = cutlass::gemm::kernel::detail::MapArguments<
    ElementQ,
    LayoutQ,
    cutlass::ComplexTransform::kNone,
    8,
    ElementK,
    LayoutK,
    cutlass::ComplexTransform::kNone,
    8,
    LayoutP,
    kInternalTranspose
  >;

  using DefaultGemmKernel = typename cutlass::gemm::kernel::DefaultGemm<
    typename MapArguments::ElementA,
    typename MapArguments::LayoutA,
    MapArguments::kAlignmentA,
    typename MapArguments::ElementB,
    typename MapArguments::LayoutB,
    MapArguments::kAlignmentB,
    ElementP,
    typename MapArguments::LayoutC,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape0,
    WarpShape0,
    InstructionShape,
    EpilogueOutputOp0,
    ThreadblockSwizzle,
    Stages0,
    true,
    Operator,
    cutlass::gemm::SharedMemoryClearOption::kNone
  >::GemmKernel;

  using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorSoftmax<
    ThreadblockShape0,
    DefaultGemmKernel::kThreadCount,
    typename DefaultGemmKernel::Epilogue::OutputTileIterator,
    typename EpilogueOutputOp0::ElementCompute,
    ElementNorm,
    ElementSum,
    ElementSoftmaxCompute,
    EpilogueOutputOp0,
    kUseMasking
  >;

  using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue<
    EpilogueVisitor,
    typename DefaultGemmKernel::Epilogue
  >::Epilogue;

  using GemmKernel0 = cutlass::gemm::kernel::GemmGroupedWithEpilogueVistor<
    typename DefaultGemmKernel::Mma,
    Epilogue,
    ThreadblockSwizzle,
    kGroupScheduleMode0,
    kInternalTranspose,
    kUseMasking
  >;

  using GemmGrouped0 = cutlass::gemm::device::GemmGrouped<GemmKernel0>;

  using ApplyFinalReductionDevice = cutlass::reduction::kernel::ApplySoftmaxFinalReduction<
    ElementNorm,
    ElementSum,
    typename GemmGrouped0::GemmKernel::EpilogueVisitor::ElementSoftmaxCompute,
    typename GemmGrouped0::GemmKernel::EpilogueVisitor::ThreadblockShape,
    true
  >;

  using GemmKernel1 = typename cutlass::gemm::kernel::DefaultGemmGroupedSoftmaxMainloopFusion<
    ElementP,
    LayoutP,
    cutlass::ComplexTransform::kNone,
    128 / cutlass::sizeof_bits<ElementQ>::value,
    ElementV,
    LayoutV,
    cutlass::ComplexTransform::kNone,
    128 / cutlass::sizeof_bits<ElementK>::value,
    ElementNorm,
    LayoutNorm,
    ElementOutput,
    LayoutO,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape1,
    WarpShape1,
    InstructionShape,
    EpilogueOutputOp1,
    ThreadblockSwizzle,
    Stages1,
    kGroupScheduleMode1
  >::GemmKernel;

  using GemmGrouped1 = cutlass::gemm::device::GemmGrouped<GemmKernel1>;

public:

  /// Arguments class
  struct Arguments {
    cutlass::gemm::GemmCoord *problem_sizes0;
    cutlass::gemm::GemmCoord *problem_sizes0_real;
    cutlass::gemm::GemmCoord *problem_sizes1;
    int problem_count;
    int threadblock_count;

    ElementQ ** ptr_Q;
    ElementK ** ptr_K;
    ElementP ** ptr_P;
    ElementP ** ptr_V;
    ElementP ** ptr_O;

    ElementNorm **ptr_Max;
    ElementSum **ptr_Sum;

    ElementP *block_P;
    ElementNorm *block_Norm;
    ElementSum *block_Sum;
    int64_t *offset_P;
    int64_t *offset_Norm_Device;
    int64_t *offset_Sum_Device;

    typename LayoutQ::Stride::LongIndex *ldq;
    typename LayoutK::Stride::LongIndex *ldk;
    typename LayoutP::Stride::LongIndex *ldp;
    typename LayoutP::Stride::LongIndex *ldv;
    typename LayoutP::Stride::LongIndex *ldo;

    cutlass::gemm::GemmCoord *problem_sizes0_host;
    cutlass::gemm::GemmCoord *problem_sizes1_host;

    ElementAccumulator alpha0;
    ElementAccumulator alpha1;
    ElementAccumulator beta;

    int head_number;
    int batch_size;
    int seq_length;

    typename ApplyFinalReductionDevice::Arguments reduction;

    //
    // Methods
    //

    Arguments():
      problem_count(0),
      threadblock_count(0),
      ptr_Q(nullptr),
      ptr_K(nullptr),
      ptr_P(nullptr),
      ptr_V(nullptr),
      ptr_O(nullptr),
      ptr_Max(nullptr),
      ptr_Sum(nullptr),
      block_P(nullptr),
      block_Norm(nullptr),
      block_Sum(nullptr),
      offset_P(nullptr),
      offset_Norm_Device(nullptr),
      offset_Sum_Device(nullptr),
      ldq(nullptr),
      ldk(nullptr),
      ldp(nullptr),
      ldv(nullptr),
      ldo(nullptr),
      head_number(0),
      batch_size(0),
      seq_length(0)
    {

    }

    Arguments(
      cutlass::gemm::GemmCoord *problem_sizes0,
      cutlass::gemm::GemmCoord *problem_sizes1,
      int problem_count,
      int threadblock_count,
      ElementQ ** ptr_Q,
      ElementK ** ptr_K,
      ElementP ** ptr_P,
      ElementP ** ptr_V,
      ElementP ** ptr_O,
      ElementNorm **ptr_Max,
      ElementSum **ptr_Sum,
      ElementP *block_P,
      ElementNorm *block_Norm,
      ElementSum *block_Sum,
      int64_t *offset_P,
      int64_t *offset_Norm_Device,
      int64_t *offset_Sum_Device,
      typename LayoutQ::Stride::LongIndex *ldq,
      typename LayoutK::Stride::LongIndex *ldk,
      typename LayoutP::Stride::LongIndex *ldp,
      typename LayoutP::Stride::LongIndex *ldv,
      typename LayoutP::Stride::LongIndex *ldo,
      ElementAccumulator alpha0,
      ElementAccumulator alpha1,
      ElementAccumulator beta,
      int head_number,
      int batch_size,
      int seq_length,
      cutlass::gemm::GemmCoord *problem_sizes0_host = nullptr,
      cutlass::gemm::GemmCoord *problem_sizes1_host = nullptr,
      cutlass::gemm::GemmCoord *problem_sizes0_real = nullptr
    ):
      problem_sizes0(problem_sizes0),
      problem_sizes1(problem_sizes1),
      problem_count(problem_count),
      threadblock_count(threadblock_count),
      ptr_Q(ptr_Q),
      ptr_K(ptr_K),
      ptr_P(ptr_P),
      ptr_V(ptr_V),
      ptr_O(ptr_O),
      ptr_Max(ptr_Max),
      ptr_Sum(ptr_Sum),
      block_P(block_P),
      block_Norm(block_Norm),
      block_Sum(block_Sum),
      offset_P(offset_P),
      offset_Norm_Device(offset_Norm_Device),
      offset_Sum_Device(offset_Sum_Device),
      ldq(ldq),
      ldk(ldk),
      ldp(ldp),
      ldv(ldv),
      ldo(ldo),
      alpha0(alpha0),
      alpha1(alpha1),
      beta(beta),
      head_number(head_number),
      batch_size(batch_size),
      seq_length(seq_length),
      problem_sizes0_host(problem_sizes0_host),
      problem_sizes1_host(problem_sizes1_host),
      problem_sizes0_real(problem_sizes0_real),
      reduction(
        problem_sizes0,
        block_Norm,
        block_Sum,
        offset_Norm_Device,
        offset_Sum_Device
      )
    {

    }

  };

  struct Params {
    cutlass::gemm::GemmCoord *problem_sizes0;
    cutlass::gemm::GemmCoord *problem_sizes0_real;
    cutlass::gemm::GemmCoord *problem_sizes1;
    int problem_count;
    int threadblock_count;

    ElementQ ** ptr_Q;
    ElementK ** ptr_K;
    ElementP ** ptr_P;
    ElementP ** ptr_V;
    ElementP ** ptr_O;

    ElementNorm **ptr_Max;
    ElementSum **ptr_Sum;

    ElementP *block_P;
    ElementNorm *block_Norm;
    ElementSum *block_Sum;
    int64_t *offset_P;
    int64_t *offset_Norm_Device;
    int64_t *offset_Sum_Device;

    typename LayoutQ::Stride::LongIndex *ldq;
    typename LayoutK::Stride::LongIndex *ldk;
    typename LayoutP::Stride::LongIndex *ldp;
    typename LayoutP::Stride::LongIndex *ldv;
    typename LayoutP::Stride::LongIndex *ldo;

    cutlass::gemm::GemmCoord *problem_sizes0_host;
    cutlass::gemm::GemmCoord *problem_sizes1_host;

    ElementAccumulator alpha0;
    ElementAccumulator alpha1;
    ElementAccumulator beta;

    int head_number;
    int batch_size;
    int seq_length;

    typename ApplyFinalReductionDevice::Params reduction;

    Params():
      problem_count(0),
      threadblock_count(0),
      ptr_Q(nullptr),
      ptr_K(nullptr),
      ptr_P(nullptr),
      ptr_V(nullptr),
      ptr_O(nullptr),
      ptr_Max(nullptr),
      ptr_Sum(nullptr),
      block_P(nullptr),
      block_Norm(nullptr),
      block_Sum(nullptr),
      offset_P(nullptr),
      offset_Norm_Device(nullptr),
      offset_Sum_Device(nullptr),
      ldq(nullptr),
      ldk(nullptr),
      ldp(nullptr),
      ldv(nullptr),
      ldo(nullptr),
      problem_sizes0(nullptr),
      problem_sizes1(nullptr),
      problem_sizes0_real(nullptr),
      head_number(0),
      batch_size(0),
      seq_length(0)
    {

    }

    Params(Arguments const &args, void *workspace = nullptr):
      problem_sizes0(args.problem_sizes0),
      problem_sizes1(args.problem_sizes1),
      problem_count(args.problem_count),
      threadblock_count(args.threadblock_count),
      ptr_Q(args.ptr_Q),
      ptr_K(args.ptr_K),
      ptr_P(args.ptr_P),
      ptr_V(args.ptr_V),
      ptr_O(args.ptr_O),
      ptr_Max(args.ptr_Max),
      ptr_Sum(args.ptr_Sum),
      block_P(args.block_P),
      block_Norm(args.block_Norm),
      block_Sum(args.block_Sum),
      offset_P(args.offset_P),
      offset_Norm_Device(args.offset_Norm_Device),
      offset_Sum_Device(args.offset_Sum_Device),
      ldq(args.ldq),
      ldk(args.ldk),
      ldp(args.ldp),
      ldv(args.ldv),
      ldo(args.ldo),
      problem_sizes0_host(args.problem_sizes0_host),
      problem_sizes1_host(args.problem_sizes1_host),
      problem_sizes0_real(args.problem_sizes0_real),
      alpha0(args.alpha0),
      alpha1(args.alpha1),
      beta(args.beta),
      head_number(args.head_number),
      batch_size(args.batch_size),
      seq_length(args.seq_length),
      reduction(args.reduction)
    {

    }
  };

private:

  Params params_;
  GemmGrouped0 gemm_grouped0;
  GemmGrouped1 gemm_grouped1;

public:

  /// Ctor
  FusedMultiHeadAttention() {

  }

  /// Initialize
  Status initialize(Arguments const &args,
                    void *workspace0 = nullptr,
                    void *workspace1 = nullptr) {

    params_ = Params(args);

    typename GemmGrouped0::Arguments args_gemm0(
      params_.problem_sizes0,
      params_.problem_count,
      params_.threadblock_count,
      params_.ptr_Q,
      params_.ptr_K,
      params_.ptr_P,
      params_.ptr_P,
      params_.ptr_Max,
      params_.ptr_Sum,
      params_.ldq,
      params_.ldk,
      params_.ldp,
      params_.ldp,
      typename GemmGrouped0::GemmKernel::EpilogueVisitor::Arguments(
        {
          params_.alpha0,
          params_.beta
        }
      ),
      params_.problem_sizes0_host,
      params_.problem_sizes0_real
    );

    Status result0 = gemm_grouped0.initialize(args_gemm0, workspace0);

    typename EpilogueOutputOp1::Params epilogue_op1(params_.alpha1, params_.beta);

    typename GemmGrouped1::Arguments args_gemm1(
      params_.problem_sizes1,
      params_.problem_count,
      params_.threadblock_count,
      epilogue_op1,
      params_.ptr_P,
      params_.ptr_V,
      params_.ptr_O,
      params_.ptr_O,
      (void**)params_.ptr_Max,
      (void**)params_.ptr_Sum,
      params_.ldp,
      params_.ldv,
      params_.ldo,
      params_.ldo,
      params_.problem_sizes1_host
    );

    Status result1 = gemm_grouped1.initialize(args_gemm1, workspace1);

    if ((result0 == cutlass::Status::kSuccess) && (result1 == cutlass::Status::kSuccess)) {
      return cutlass::Status::kSuccess;
    } else {
      if (result0 != cutlass::Status::kSuccess) {
        return result0;
      } else {
        return result1;
      }
    }
  }

  /// Run
  Status run(cudaStream_t stream = nullptr) {

    Status result = gemm_grouped0.run();
    cudaError_t error_info;

    if (result != cutlass::Status::kSuccess) {
      return cutlass::Status::kErrorInternal;
    }

    int thread_per_block = 1024;

    dim3 final_reduction_grid(params_.head_number * params_.batch_size);
    dim3 final_reduction_block(thread_per_block);

    cutlass::Kernel<ApplyFinalReductionDevice><<<
      final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionDevice::SharedStorage), stream
    >>>(params_.reduction);

    error_info = cudaGetLastError();

    if (error_info != cudaSuccess) {
      return cutlass::Status::kErrorInternal;
    }

    result = gemm_grouped1.run();

    if (result != cutlass::Status::kSuccess) {
      return cutlass::Status::kErrorInternal;
    }

    return cutlass::Status::kSuccess;
  }

  /// Function call operator
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }
};

}
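For orientation, a minimal sketch of how this class might be instantiated and driven. The element types, tile shapes, and stage counts below are illustrative assumptions, not values taken from the example; the tested configuration lives in fused_multihead_attention.cu.

// Illustrative only: a plausible instantiation for SM80 Tensor Cores.
#include "gemm_attention.h"

using Attention = cutlass::FusedMultiHeadAttention<
    cutlass::half_t, cutlass::layout::RowMajor,     // Q
    cutlass::half_t, cutlass::layout::ColumnMajor,  // K
    cutlass::half_t, cutlass::layout::RowMajor,     // P (= softmax(Q K^T))
    float,                                          // accumulator type
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,  // threadblock tile, GEMM0
    cutlass::gemm::GemmShape<128, 128, 32>,  // threadblock tile, GEMM1
    cutlass::gemm::GemmShape<64, 64, 32>,    // warp tile, GEMM0
    cutlass::gemm::GemmShape<64, 64, 32>,    // warp tile, GEMM1
    cutlass::gemm::GemmShape<16, 8, 16>,     // Tensor Core instruction shape
    4, 4>;                                   // stages for GEMM0 / GEMM1

cutlass::Status run_attention(Attention::Arguments const &args) {
  Attention op;
  cutlass::Status status = op.initialize(args);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }
  return op();  // GEMM0 + partial softmax, final reduction, then GEMM1
}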
@ -0,0 +1,522 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Grouped GEMM kernel with epilogue visitor customized for softmax
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Mma_,                          ///! Threadblock-scoped matrix multiply-accumulate
  typename Epilogue_,                     ///! Epilogue
  typename ThreadblockSwizzle_,           ///! Threadblock swizzling function
  GroupScheduleMode GroupScheduleMode_,   ///! Type of scheduling to perform
  bool Transposed_ = false,
  bool UseMask_ = false
>
struct GemmGroupedWithEpilogueVistor {
public:

  using Mma = Mma_;
  using Epilogue = Epilogue_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;

  static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;

  using EpilogueVisitor = typename Epilogue::Visitor;
  using EpilogueOutputOp = typename EpilogueVisitor::ElementwiseFunctor;
  static bool const kTransposed = Transposed_;

  // Optional transpose
  using MapArguments = kernel::detail::MapArguments<
    typename Mma::IteratorA::Element,
    typename Mma::IteratorA::Layout,
    Mma::kTransformA,
    Mma::IteratorA::AccessType::kElements,
    typename Mma::IteratorB::Element,
    typename Mma::IteratorB::Layout,
    Mma::kTransformB,
    Mma::IteratorB::AccessType::kElements,
    typename Mma::LayoutC,
    kTransposed
  >;

  // Public-facing type definitions related to operand element type, layout, and complex conjugate
  // operation. Must interact with the 'kTransposed' notion.
  using ElementA = typename MapArguments::ElementA;
  using LayoutA = typename MapArguments::LayoutA;
  using ElementB = typename MapArguments::ElementB;
  using LayoutB = typename MapArguments::LayoutB;
  using ElementC = typename EpilogueVisitor::ElementOutput;
  using LayoutC = typename MapArguments::LayoutC;

  using ElementNorm = typename EpilogueVisitor::ElementNorm;
  using ElementSum = typename EpilogueVisitor::ElementSum;

  static ComplexTransform const kTransformA = MapArguments::kTransformA;
  static ComplexTransform const kTransformB = MapArguments::kTransformB;

  // Type definitions about the mainloop.
  using Operator = typename Mma::Operator;
  using OperatorClass = typename Mma::Operator::OperatorClass;
  using ThreadblockShape = typename Mma::Shape;
  using WarpShape = typename Mma::Operator::Shape;
  using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
  using ArchTag = typename Mma::ArchTag;

  static int const kStages = Mma::kStages;
  static int const kAlignmentA = MapArguments::kAlignmentA;
  static int const kAlignmentB = MapArguments::kAlignmentB;
  static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;

  /// Warp count (concept: GemmShape)
  using WarpCount = typename Mma::WarpCount;
  static int const kThreadCount = 32 * WarpCount::kCount;

  using ProblemVisitor = GemmGroupedProblemVisitor<
    ThreadblockShape,
    kGroupScheduleMode,
    kThreadCount,
    kThreadCount,
    kTransposed>;

  //
  // Structures
  //

  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    GemmCoord *problem_sizes;
    // When masking is used, the real problem sizes may not be aligned,
    // so the padded elements need to be masked out in softmax.
    GemmCoord *problem_sizes_real;
    int problem_count;
    int threadblock_count;

    ElementA ** ptr_A;
    ElementB ** ptr_B;
    ElementC ** ptr_C;
    ElementC ** ptr_D;

    ElementNorm **ptr_Max;
    ElementSum **ptr_Sum;

    typename LayoutA::Stride::LongIndex *lda;
    typename LayoutB::Stride::LongIndex *ldb;
    typename LayoutC::Stride::LongIndex *ldc;
    typename LayoutC::Stride::LongIndex *ldd;

    typename EpilogueVisitor::Arguments epilogue_visitor;

    // Only used by device-level operator
    GemmCoord *host_problem_sizes;

    //
    // Methods
    //

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Arguments():
      problem_count(0),
      threadblock_count(0),
      ptr_A(nullptr),
      ptr_B(nullptr),
      ptr_C(nullptr),
      ptr_D(nullptr),
      ptr_Max(nullptr),
      ptr_Sum(nullptr),
      lda(nullptr),
      ldb(nullptr),
      ldc(nullptr),
      ldd(nullptr),
      host_problem_sizes(nullptr)
    {

    }

    /// Ctor
    CUTLASS_HOST_DEVICE
    Arguments(
      GemmCoord *problem_sizes,
      int problem_count,
      int threadblock_count,
      ElementA ** ptr_A,
      ElementB ** ptr_B,
      ElementC ** ptr_C,
      ElementC ** ptr_D,
      ElementNorm **ptr_Max,
      ElementSum **ptr_Sum,
      typename LayoutA::Stride::LongIndex *lda,
      typename LayoutB::Stride::LongIndex *ldb,
      typename LayoutC::Stride::LongIndex *ldc,
      typename LayoutC::Stride::LongIndex *ldd,
      typename EpilogueVisitor::Arguments epilogue_visitor_,
      GemmCoord *host_problem_sizes = nullptr,
      GemmCoord *problem_sizes_real = nullptr
    ):
      problem_sizes(problem_sizes),
      problem_count(problem_count),
      threadblock_count(threadblock_count),
      ptr_A(ptr_A),
      ptr_B(ptr_B),
      ptr_C(ptr_C),
      ptr_D(ptr_D),
      ptr_Max(ptr_Max),
      ptr_Sum(ptr_Sum),
      lda(lda),
      ldb(ldb),
      ldc(ldc),
      ldd(ldd),
      epilogue_visitor(epilogue_visitor_),
      host_problem_sizes(host_problem_sizes),
      problem_sizes_real(problem_sizes_real)
    {

    }
  };

  //
  // Structure for precomputing values in host memory and passing to kernels
  //

  /// Parameters structure
  struct Params {

    typename ProblemVisitor::Params problem_visitor;
    GemmCoord *problem_sizes_real;
    int threadblock_count;

    ElementA ** ptr_A;
    ElementB ** ptr_B;
    ElementC ** ptr_C;
    ElementC ** ptr_D;

    ElementNorm **ptr_Max;
    ElementSum **ptr_Sum;

    typename LayoutA::Stride::LongIndex *lda;
    typename LayoutB::Stride::LongIndex *ldb;
    typename LayoutC::Stride::LongIndex *ldc;
    typename LayoutC::Stride::LongIndex *ldd;

    typename EpilogueVisitor::Params epilogue_visitor;

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params():
      ptr_A(nullptr),
      ptr_B(nullptr),
      ptr_C(nullptr),
      ptr_D(nullptr),
      ptr_Max(nullptr),
      ptr_Sum(nullptr),
      lda(nullptr),
      ldb(nullptr),
      ldc(nullptr),
      ldd(nullptr),
      problem_sizes_real(nullptr)
    { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args, void *workspace = nullptr, int32_t tile_count = 0):
      problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count),
      threadblock_count(args.threadblock_count),
      ptr_A(args.ptr_A),
      ptr_B(args.ptr_B),
      ptr_C(args.ptr_C),
      ptr_D(args.ptr_D),
      ptr_Max(args.ptr_Max),
      ptr_Sum(args.ptr_Sum),
      lda(args.lda),
      ldb(args.ldb),
      ldc(args.ldc),
      ldd(args.ldd),
      epilogue_visitor(args.epilogue_visitor),
      problem_sizes_real(args.problem_sizes_real)
    {

    }

    CUTLASS_HOST_DEVICE
    void update(
      Arguments const &args,
      void *workspace = nullptr,
      int32_t tile_count = -1) {

      problem_visitor = typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count);
      threadblock_count = args.threadblock_count;
      ptr_A = args.ptr_A;
      ptr_B = args.ptr_B;
      ptr_C = args.ptr_C;
      ptr_D = args.ptr_D;
      ptr_Max = args.ptr_Max;
      ptr_Sum = args.ptr_Sum;
      lda = args.lda;
      ldb = args.ldb;
      ldc = args.ldc;
      ldd = args.ldd;
      problem_sizes_real = args.problem_sizes_real;
    }
  };

  /// Shared memory storage structure
  struct SharedStorage {
    union {
      typename Mma::SharedStorage main_loop;

      struct {
        typename Epilogue::SharedStorage epilogue;
        typename EpilogueVisitor::SharedStorage visitor;
      } epilogue;
    } kernel;

    // ProblemVisitor shared storage can't be overlapped with others
    typename ProblemVisitor::SharedStorage problem_visitor;
  };

public:

  //
  // Methods
  //

  CUTLASS_DEVICE
  GemmGroupedWithEpilogueVistor() { }

  /// Determines whether the kernel satisfies alignment
  static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) {
    return Status::kSuccess;
  }

  static Status can_implement(Arguments const &args) {
    return Status::kSuccess;
  }

  static size_t get_extra_workspace_size(
    Arguments const &args,
    cutlass::gemm::GemmCoord const &grid_tiled_shape) {

    return 0;
  }

  /// Executes one GEMM
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {

    //
    // These types shadow the type-level definitions and support the ability to implement
    // a 'transposed' GEMM that computes the transposed problems.
    //
    using ElementA = typename Mma::IteratorA::Element;
    using LayoutA = typename Mma::IteratorA::Layout;
    using ElementB = typename Mma::IteratorB::Element;
    using LayoutB = typename Mma::IteratorB::Layout;
    using ElementC = typename EpilogueVisitor::ElementOutput;
    using LayoutC = typename Mma::LayoutC;

    //
    // Problem visitor.
    //
    ProblemVisitor problem_visitor(
      params.problem_visitor,
      shared_storage.problem_visitor,
      blockIdx.x);

    // Outer 'persistent' loop to iterate over tiles
    while (problem_visitor.next_tile()) {

      GemmCoord problem_size = problem_visitor.problem_size();
      int32_t problem_idx = problem_visitor.problem_index();
      int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());

      GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);

      cutlass::gemm::GemmCoord threadblock_offset(
        int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM,
        int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN,
        0);

      // Load element pointers. Exchange pointers and strides if working on the transpose
      ElementA *ptr_A = reinterpret_cast<ElementA *>((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx]));
      typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]);

      ElementB *ptr_B = reinterpret_cast<ElementB *>((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx]));
      typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]);

      // Compute initial location in logical coordinates
      cutlass::MatrixCoord tb_offset_A{
        threadblock_offset.m(),
        0,
      };

      cutlass::MatrixCoord tb_offset_B{
        0,
        threadblock_offset.n()
      };

      // Compute position within threadblock
      int thread_idx = threadIdx.x;

      // Construct iterators to A and B operands
      typename Mma::IteratorA iterator_A(
        LayoutA(ldm_A),
        ptr_A,
        {problem_size.m(), problem_size.k()},
        thread_idx,
        tb_offset_A);

      typename Mma::IteratorB iterator_B(
        LayoutB(ldm_B),
        ptr_B,
        {problem_size.k(), problem_size.n()},
        thread_idx,
        tb_offset_B);

      typename Mma::FragmentC accumulators;

      accumulators.clear();

      // Broadcast the warp_id computed by lane 0 to ensure dependent code
      // is compiled as warp-uniform.
      int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);

      int lane_idx = threadIdx.x % 32;

      //
      // Matrix multiply phase
      //

      // Construct thread-scoped matrix multiply
      Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx);

      // Compute threadblock-scoped matrix multiply-add
      int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;

      // Wait for all threads to finish their epilogue phases from the previous tile.
      __syncthreads();

      // Compute threadblock-scoped matrix multiply-add
      mma(
        gemm_k_iterations,
        accumulators,
        iterator_A,
        iterator_B,
        accumulators);

      ElementC *ptr_C = params.ptr_C[problem_idx];
      ElementC *ptr_D = params.ptr_D[problem_idx];

      ElementNorm *ptr_Max = params.ptr_Max[problem_idx];
      ElementSum *ptr_Sum = params.ptr_Sum[problem_idx];

      LayoutC layout_C(params.ldc[problem_idx]);
      LayoutC layout_D(params.ldd[problem_idx]);

      int column_offset = (threadblock_offset.n() / ThreadblockShape::kN) * problem_size.m();

      typename EpilogueVisitor::OutputTileIterator::Params params_C(layout_C);
      typename EpilogueVisitor::OutputTileIterator::Params params_D(layout_D);

      //
      // Construct the epilogue visitor
      //

      EpilogueVisitor epilogue_visitor(
        params.epilogue_visitor,
        shared_storage.kernel.epilogue.visitor,
        problem_size.mn(),
        thread_idx,
        warp_idx,
        lane_idx,
        params_C,
        params_D,
        ptr_C,
        ptr_D,
        ptr_Max,
        ptr_Sum,
        threadblock_offset.mn(),
        column_offset,
        params.problem_sizes_real[problem_idx].mn()
      );

      // Construct the epilogue
      Epilogue epilogue(
        shared_storage.kernel.epilogue.epilogue,
        thread_idx,
        warp_idx,
        lane_idx);

      // Execute the epilogue operator to update the destination tensor
      epilogue(epilogue_visitor, accumulators);

      // Next tile
      problem_visitor.advance(gridDim.x);
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
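The kernel above is persistent: each threadblock repeatedly asks the problem visitor for its next tile across all grouped GEMMs instead of being mapped to a single tile. The following standalone host-side sketch (hypothetical helper code, not CUTLASS API) illustrates the scheduling idea, with a linear block index striding over the concatenated tile ranges of every problem.

// Host-side illustration of persistent grouped-tile scheduling.
#include <cstdio>
#include <vector>

int main() {
  // Tiles per problem for three hypothetical GEMMs (e.g. ceil(M/128) * ceil(N/128)).
  std::vector<int> tiles = {6, 2, 9};
  int grid_size = 4;  // number of persistent "threadblocks"

  for (int block = 0; block < grid_size; ++block) {
    // Each block visits linear tile indices block, block + grid, block + 2*grid, ...
    for (int t = block;; t += grid_size) {
      int problem = 0, first = 0;
      while (problem < (int)tiles.size() && t >= first + tiles[problem]) {
        first += tiles[problem++];
      }
      if (problem == (int)tiles.size()) break;  // all tiles exhausted
      std::printf("block %d -> problem %d, tile %d\n", block, problem, t - first);
    }
  }
  return 0;
}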
@ -116,6 +116,10 @@ foreach(EXAMPLE
  34_transposed_conv2d
  35_gemm_softmax
  36_gather_scatter_fusion
  37_gemm_layernorm_gemm_fusion
  38_syr2k_grouped
  39_gemm_permute
  41_multi_head_attention
)

  add_subdirectory(${EXAMPLE})

@ -98,6 +98,9 @@ template <
    bool IsHermitianData = false>
struct cp_async_diag;

static const uint32_t OOB_NAN_F16 = 0x7eff;
static const uint32_t OOB_NAN_F16x2 = ((OOB_NAN_F16 << 16) | OOB_NAN_F16);

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization
@ -190,8 +193,8 @@ struct cp_async_nan<16, CacheOperation::Always> {
  cp_async_nan(void *smem_ptr, void const *global_ptr, bool pred_guard) {
#if CUDA_CP_ASYNC_ACTIVATED

    static __constant__ uint4 OOB_NAN_F16x8 = {0x7eff7eff, 0x7eff7eff,
                                               0x7eff7eff, 0x7eff7eff};
    static __constant__ uint4 OOB_NAN_F16x8 = {OOB_NAN_F16x2, OOB_NAN_F16x2,
                                               OOB_NAN_F16x2, OOB_NAN_F16x2};

    unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr);

@ -305,7 +308,6 @@ struct cp_async_diag <Element_, true> {
  }
};


////////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization
@ -386,6 +388,47 @@ struct cp_async_zfill<SizeInBytes, CacheOperation::Global> {
  }
};

/// Partial specialization
template <>
struct cp_async_nan<16, CacheOperation::Global> {
  static int const kSizeInBytes = 16;

  /// Copy with nan fill
  CUTLASS_DEVICE
  cp_async_nan(void *smem_ptr, void const *global_ptr, bool pred_guard) {
#if CUDA_CP_ASYNC_ACTIVATED

    static __constant__ uint4 OOB_NAN_F16x8 = {OOB_NAN_F16x2, OOB_NAN_F16x2,
                                               OOB_NAN_F16x2, OOB_NAN_F16x2};

    unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr);

    asm volatile(
        "{\n"
        "  .reg .pred p;\n"
        "  setp.ne.b32 p, %0, 0;\n"
#if CUTLASS_ENABLE_L2_PREFETCH
        "  @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n"
#else
        "  @p cp.async.cg.shared.global [%1], [%2], %3;\n"
#endif
        "  @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};\n"
        "}\n"
        :
        : "r"((int)pred_guard), "r"(smem_int_ptr), "l"(global_ptr),
          "n"(kSizeInBytes), "r"(OOB_NAN_F16x8.x), "r"(OOB_NAN_F16x8.y), "r"(OOB_NAN_F16x8.z),
          "r"(OOB_NAN_F16x8.w));

#else

    CUTLASS_UNUSED(smem_ptr);
    CUTLASS_UNUSED(global_ptr);
    CUTLASS_UNUSED(pred_guard);
    CUTLASS_NOT_IMPLEMENTED();

#endif
  }
};
////////////////////////////////////////////////////////////////////////////////////////////////////

/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block.

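The new cp_async_nan<16, CacheOperation::Global> specialization above copies 16 bytes from global to shared memory when the guard predicate holds, and otherwise stores the 0x7eff half-precision NaN pattern. A host-side sketch of just the fill semantics (the real path is the inline PTX above):

// Host-side illustration of the NaN-fill semantics, not the cp.async path.
#include <cstdint>
#include <cstring>

void copy_or_nan_fill_16B(void *dst, void const *src, bool pred_guard) {
  if (pred_guard) {
    std::memcpy(dst, src, 16);                   // the guarded copy
  } else {
    uint16_t nan_f16 = 0x7eff;                   // OOB_NAN_F16
    uint16_t *d = static_cast<uint16_t *>(dst);
    for (int i = 0; i < 8; ++i) d[i] = nan_f16;  // 8 halves = 16 bytes
  }
}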
@ -126,6 +126,10 @@ struct Mma<
        : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(C[0]), "r"(C[1]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -188,6 +192,10 @@ struct Mma<
    );

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -251,6 +259,10 @@ struct Mma<
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -308,6 +320,10 @@ struct Mma<
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -366,6 +382,10 @@ struct Mma<


#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -423,6 +443,10 @@ struct Mma<
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -486,6 +510,10 @@ struct Mma<
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -543,6 +571,10 @@ struct Mma<
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -600,6 +632,10 @@ struct Mma<
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -657,6 +693,10 @@ struct Mma<
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -711,7 +751,6 @@ struct Mma<

    unsigned const & A = reinterpret_cast<unsigned const &>(a);
    unsigned const & B = reinterpret_cast<unsigned const &>(b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

@ -720,6 +759,10 @@ struct Mma<
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -777,6 +820,10 @@ struct Mma<
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -834,6 +881,10 @@ struct Mma<
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -891,6 +942,10 @@ struct Mma<
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -954,6 +1009,10 @@ struct Mma<
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -1011,6 +1070,10 @@ struct Mma<
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -1068,6 +1131,10 @@ struct Mma<
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -1125,6 +1192,10 @@ struct Mma<
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -1210,11 +1281,19 @@ struct Mma<
        nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);
#else

    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0); // WMMA must be supported to issue binary matrix multiply-accumulate instructions.

#endif // defined(CUTLASS_ARCH_WMMA_ENABLED)

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif


@ -587,6 +587,10 @@ struct Mma<
        "r"(C[3]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -1571,6 +1575,10 @@ struct Mma<
        "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -1631,6 +1639,10 @@ struct Mma<
        "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -1691,6 +1703,10 @@ struct Mma<
        "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -1751,6 +1767,10 @@ struct Mma<
        "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -1818,6 +1838,10 @@ struct Mma<
        "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -1878,6 +1902,10 @@ struct Mma<
        "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -1938,6 +1966,10 @@ struct Mma<
        "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -1998,6 +2030,10 @@ struct Mma<
        "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -2059,6 +2095,10 @@ struct Mma<
        "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));

#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -2126,6 +2166,10 @@ struct Mma<

#else

    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);

#endif // defined(CUTLASS_ARCH_MMA_SM80_ENABLED)

@ -141,6 +141,10 @@ struct SparseMma<
      assert(0);
    }
#else
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -224,6 +228,10 @@ struct SparseMma<

#else

    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -296,6 +304,10 @@ struct SparseMma<gemm::GemmShape<16, 8, 32>, 32, bfloat16_t, layout::RowMajor,

#else

    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -368,6 +380,10 @@ struct SparseMma<gemm::GemmShape<16, 8, 16>, 32, tfloat32_t, layout::RowMajor,

#else

    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -449,6 +465,10 @@ struct SparseMma<

#else

    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -524,6 +544,10 @@ struct SparseMma<

#else

    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -599,6 +623,10 @@ struct SparseMma<

#else

    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -674,6 +702,10 @@ struct SparseMma<

#else

    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -755,6 +787,10 @@ struct SparseMma<

#else

    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -830,6 +866,10 @@ struct SparseMma<

#else

    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -905,6 +945,10 @@ struct SparseMma<

#else

    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -980,6 +1024,10 @@ struct SparseMma<

#else

    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -1061,6 +1109,10 @@ struct SparseMma<

#else

    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -1136,6 +1188,10 @@ struct SparseMma<

#else

    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -1211,6 +1267,10 @@ struct SparseMma<

#else

    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -1286,6 +1346,10 @@ struct SparseMma<

#else

    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -1367,6 +1431,10 @@ struct SparseMma<

#else

    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -1442,6 +1510,10 @@ struct SparseMma<

#else

    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -1517,6 +1589,10 @@ struct SparseMma<

#else

    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }
@ -1592,6 +1668,10 @@ struct SparseMma<

#else

    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_UNUSED(d);
    assert(0);
#endif
  }

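Every Mma and SparseMma hunk in this block adds the same fallback: when the target architecture macro is absent, the operands are marked as used and the kernel asserts rather than silently emitting a no-op. A minimal self-contained rendering of the pattern follows, with a stand-in CUTLASS_UNUSED definition (the real macro lives in cutlass/cutlass.h, so the exact expansion here is an assumption).

// Sketch of the arch-guard fallback pattern used by the hunks above.
#include <cassert>

#ifndef CUTLASS_UNUSED
#define CUTLASS_UNUSED(expr) do { (void)(expr); } while (0)  // stand-in only
#endif

template <typename FragA, typename FragB, typename FragC, typename FragD>
void mma_fallback(FragD &d, FragA const &a, FragB const &b, FragC const &c) {
#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
  // issue the real PTX mma instruction here
#else
  CUTLASS_UNUSED(a);
  CUTLASS_UNUSED(b);
  CUTLASS_UNUSED(c);
  CUTLASS_UNUSED(d);
  assert(0);  // compiled for an architecture without this instruction
#endif
}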
@ -127,8 +127,8 @@ template <typename T>
class complex
{
public:
  /// Type alias for scalar type
  using value_type = T;
  /// Type alias for scalar type
  using value_type = T;

private:
  //

@ -268,7 +268,7 @@ public:
  CUTLASS_HOST_DEVICE
  int64_t filter_size() const {

    return (K * R * S * C);
    return (K * R * S * C / groups);
  }

  /// Returns output size in number of elements
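The corrected formula accounts for grouped convolution: each of the K filters spans only C / groups input channels. A quick numeric check under assumed sizes:

// Worked check: K = 64 output channels, R = S = 3, C = 64 input channels,
// groups = 4. Each filter sees C/groups = 16 channels, so the tensor holds
// 64*3*3*16 = 9216 elements rather than 36864.
#include <cstdint>
#include <cstdio>

int main() {
  int64_t K = 64, R = 3, S = 3, C = 64, groups = 4;
  std::printf("filter_size = %lld\n",
              (long long)(K * R * S * C / groups));  // prints 9216
  return 0;
}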
@ -362,61 +362,128 @@ int implicit_gemm_k_iterations(
  Operator conv_operator,
  int threadblock_K,
  Conv2dProblemSize const &problem_size,
  IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic) {
  IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic,
  GroupMode group_mode = GroupMode::kNone,
  int threadblock_N = 0) {

  int iterations = 0;

  if (algorithm == IteratorAlgorithm::kFixedChannels) {
  if (group_mode == GroupMode::kNone) {

    int positions_per_iteration = threadblock_K / problem_size.C;
    switch (conv_operator) {
    case Operator::kFprop:
      iterations = (problem_size.R * problem_size.S + positions_per_iteration - 1) / positions_per_iteration;
      break;
    if (algorithm == IteratorAlgorithm::kFixedChannels) {

    default:
      break;
      int positions_per_iteration = threadblock_K / problem_size.C;
      switch (conv_operator) {
      case Operator::kFprop:
        iterations = (problem_size.R * problem_size.S + positions_per_iteration - 1) / positions_per_iteration;
        break;

      default:
        break;
      }
    }
  }
  else if (algorithm == IteratorAlgorithm::kFewChannels) {
    else if (algorithm == IteratorAlgorithm::kFewChannels) {

    switch (conv_operator) {
    case Operator::kFprop:
      iterations = (problem_size.R * problem_size.S * problem_size.C + threadblock_K - 1) / threadblock_K;
      break;
      switch (conv_operator) {
      case Operator::kFprop:
        iterations = (problem_size.R * problem_size.S * problem_size.C + threadblock_K - 1) / threadblock_K;
        break;

    default:
      break;
      default:
        break;
      }
    }
  }
  else {
    int elements_per_split_k_slice = 0;
    else {
      int elements_per_split_k_slice = 0;

    switch (conv_operator) {
    case Operator::kFprop:
      elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
      iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
      break;
      switch (conv_operator) {
      case Operator::kFprop:
        elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
        iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
        break;

    case Operator::kDgrad:
      elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
      iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
      break;
      case Operator::kDgrad:
        elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
        iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
        break;

    case Operator::kWgrad:
      elements_per_split_k_slice = (problem_size.N * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
      iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K;
      break;
      case Operator::kWgrad:
        elements_per_split_k_slice = (problem_size.N * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
        iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K;
        break;

    default:
      break;
      default:
        break;
      }
    }

  } else if (group_mode == GroupMode::kDepthwise) {
    int channels_per_cta = threadblock_N;

    if (algorithm == IteratorAlgorithm::kAnalytic) {
      switch (conv_operator) {
      case Operator::kFprop:
        iterations = problem_size.R * problem_size.S *
                     ((channels_per_cta + threadblock_K - 1) / threadblock_K);
        break;

      default:
        break;
      }
    }
  } else { // Group conv

    int channels_per_group = problem_size.C / problem_size.groups;
    int k_per_group = problem_size.K / problem_size.groups;

    if (algorithm == IteratorAlgorithm::kAnalytic) {
      switch (conv_operator) {
      case Operator::kFprop:
        iterations = problem_size.R * problem_size.S * ((channels_per_group + threadblock_K - 1) / threadblock_K);
        // In group conv, if k_per_group < threadblock_N, one threadblock will calculate multiple groups
        if (problem_size.groups != 1) {
          if (k_per_group < threadblock_N) {
            iterations *= threadblock_N / k_per_group;
          }
        }
        break;

      default:
        break;
      }
    }

  }

  return iterations;
}
|
||||
|
||||
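A quick sanity check on the non-grouped fprop formula above; this standalone sketch uses assumed values and is not part of the diff.

// Standalone sketch (assumed values): k-iteration count for a non-grouped
// fprop problem with C = 96 input channels and threadblock_K = 32.
#include <cassert>

int main() {
  int C = 96, R = 3, S = 3, split_k_slices = 1, threadblock_K = 32;
  int elements_per_split_k_slice = (C + split_k_slices - 1) / split_k_slices;  // 96
  int iterations = R * S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
  assert(iterations == 3 * 3 * 3);  // 27 mainloop iterations over the 3x3x96 filter
  return 0;
}
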
CUTLASS_HOST_DEVICE
int implicit_gemm_k_iterations_per_channel(
  Operator conv_operator,
  int threadblock_K,
  Conv2dProblemSize const &problem_size,
  IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic) {

  int iterations = 0;  // 0 means not applicable
  if (algorithm == IteratorAlgorithm::kAnalytic || algorithm == IteratorAlgorithm::kOptimized) {
    switch (conv_operator) {
    case Operator::kFprop:
      iterations = problem_size.R * problem_size.S;
      break;

    case Operator::kDgrad:
      iterations = problem_size.R * problem_size.S;
      break;

    default:
      break;
    }
  }
  return iterations;
}

////////////////////////////////////////////////////////////////////////////////
// Mapping function (ImplicitGemm A, B, C -> Conv Activation, Filter, Output)
////////////////////////////////////////////////////////////////////////////////

@@ -537,12 +604,12 @@ void strided_dgrad_starting_coords(
   // function locals for remainder by fast divmod
   int pad_h_rem_, pad_w_rem_;

-  // start_h = platform::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h;
+  // start_h = std::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h;
   stride_h_divmod.divmod(pad_h_rem_, problem_size.pad_h);
   int r_ = absolute_value(problem_size.stride_h - (pad_h_rem_ - r));
   stride_h_divmod.divmod(start_h, r_);

-  // start_w = platform::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w;
+  // start_w = std::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w;
   stride_w_divmod.divmod(pad_w_rem_, problem_size.pad_w);
   int s_ = absolute_value(problem_size.stride_w - (pad_w_rem_ - s));
   stride_w_divmod.divmod(start_w, s_);
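The fast-divmod pair implements exactly the modulo formula quoted in the comments. A plain-arithmetic reference sketch under assumed values (stride_h = 2, pad_h = 1, filter row r = 0):

// Reference sketch (assumed values, not from the diff): the divmod pair is
// equivalent to this plain modulo arithmetic.
#include <cassert>
#include <cstdlib>

int main() {
  int stride_h = 2, pad_h = 1, r = 0;
  int pad_h_rem = pad_h % stride_h;               // divmod remainder: 1
  int r_ = std::abs(stride_h - (pad_h_rem - r));  // 2 - (1 - 0) = 1
  int start_h = r_ % stride_h;                    // divmod remainder: 1
  assert(start_h == 1);  // first output row this filter row contributes to
  return 0;
}
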
@@ -339,29 +339,46 @@ int implicit_gemm_k_iterations(
   Operator conv_operator,
   int threadblock_K,
   Conv3dProblemSize const &problem_size,
-  IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic) {
+  IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic,
+  GroupMode group_mode = GroupMode::kNone,
+  int threadblock_N = 0) {

   int iterations = 0;
   int elements_per_split_k_slice = 0;
+  if (group_mode == GroupMode::kNone) {
     switch (conv_operator) {
     case Operator::kFprop:
       elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
       iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
       break;

     case Operator::kDgrad:
       elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
       iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
       break;

     case Operator::kWgrad:
       elements_per_split_k_slice = (problem_size.N * problem_size.Z * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
       iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K;
       break;

     default:
       break;
     }
+  } else if (group_mode == GroupMode::kDepthwise) {
+    int channels_per_cta = threadblock_N;
+
+    if (algorithm == IteratorAlgorithm::kAnalytic) {
+      switch (conv_operator) {
+      case Operator::kFprop:
+        iterations = problem_size.T * problem_size.R * problem_size.S *
+            ((channels_per_cta + threadblock_K - 1) / threadblock_K);
+        break;
+
+      default:
+        break;
+      }
+    }
+  }

   return iterations;
@@ -117,6 +117,14 @@ enum class SplitKMode {
   kParallel
 };

+/// Identifies group mode
+enum class GroupMode {
+  kNone,
+  kSingleGroup,   ///< One CTA calculates one group or less
+  kMultipleGroup, ///< One CTA calculates multiple groups
+  kDepthwise      ///< One CTA calculates cta_n groups (problem_size.C == problem_size.K == problem_size.groups)
+};
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 } // namespace conv

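To make the three grouped modes concrete, here is a minimal sketch of how they relate to problem shape; the selection heuristic is an assumption for illustration, not the library's dispatch logic, and threadblock_N = 64 is an assumed tile width.

// Illustrative sketch (assumptions, not from the diff).
#include <cassert>

enum class GroupMode { kNone, kSingleGroup, kMultipleGroup, kDepthwise };

GroupMode select_group_mode(int C, int K, int groups, int threadblock_N) {
  if (groups == 1)           return GroupMode::kNone;
  if (C == K && K == groups) return GroupMode::kDepthwise;
  int k_per_group = K / groups;
  // One CTA covers one group when the group's output channels fill whole tiles;
  // otherwise one CTA must span several groups.
  return (k_per_group >= threadblock_N) ? GroupMode::kSingleGroup
                                        : GroupMode::kMultipleGroup;
}

int main() {
  assert(select_group_mode(128, 128,   1, 64) == GroupMode::kNone);
  assert(select_group_mode(128, 128, 128, 64) == GroupMode::kDepthwise);
  assert(select_group_mode(128, 128,   2, 64) == GroupMode::kSingleGroup);    // k_per_group = 64
  assert(select_group_mode(128, 128,   4, 64) == GroupMode::kMultipleGroup);  // k_per_group = 32
  return 0;
}
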
@@ -78,6 +78,7 @@ public:
   static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemmKernel::kConvolutionalOperator;
   static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = ImplicitGemmKernel::kIteratorAlgorithm;
   static cutlass::conv::StrideSupport const kStrideSupport = ImplicitGemmKernel::kStrideSupport;
+  static cutlass::conv::GroupMode const kGroupMode = ImplicitGemmKernel::kGroupMode;

   static int const kWarpCount =
     (ThreadblockShape::kM / WarpShape::kM) *
@@ -111,6 +112,34 @@ public:
       return status;
     }

+    // Check group conv constraints
+    if (args.problem_size.groups != 1) {
+      if (kGroupMode == conv::GroupMode::kNone) {
+        return Status::kErrorInvalidProblem;
+      }
+
+      // C and K must both be multiples of the group count
+      if (args.problem_size.K % args.problem_size.groups ||
+          args.problem_size.C % args.problem_size.groups) {
+        return Status::kErrorInvalidProblem;
+      }
+
+      // split-k is not supported
+      if (args.problem_size.split_k_slices != 1) {
+        return Status::kErrorInvalidProblem;
+      }
+
+      int k_per_group = args.problem_size.K / args.problem_size.groups;
+      // k_per_group must be a multiple of ThreadblockShape::kN when one CTA calculates one group
+      if (kGroupMode == conv::GroupMode::kSingleGroup && k_per_group % ThreadblockShape::kN) {
+        return Status::kErrorInvalidProblem;
+      }
+      // ThreadblockShape::kN must be divisible by k_per_group when one CTA calculates multiple groups
+      if (kGroupMode == conv::GroupMode::kMultipleGroup && ThreadblockShape::kN % k_per_group) {
+        return Status::kErrorInvalidProblem;
+      }
+    }
+
     static int const kAlignmentC = ImplicitGemmKernel::Epilogue::OutputTileIterator::kElementsPerAccess;
     if (kConvolutionalOperator == conv::Operator::kFprop) {
       if (args.problem_size.K % kAlignmentC)
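A worked instance of the constraints above, with assumed shapes (not from the diff), for a kSingleGroup kernel whose ThreadblockShape::kN is 64:

// Worked example (assumed shapes).
#include <cassert>

int main() {
  int C = 256, K = 256, groups = 4, threadblock_N = 64;
  assert(K % groups == 0 && C % groups == 0);  // C, K divisible by groups
  int k_per_group = K / groups;                // 64 output channels per group
  assert(k_per_group % threadblock_N == 0);    // one CTA tile covers one group exactly
  return 0;
}
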
@@ -45,8 +45,8 @@
 #include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
 #include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
 #include "cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h"
-#include "cutlass/conv/threadblock/regular_scale_bias_vector_access_iterator.h"
-#include "cutlass/conv/warp/conv2d_fprop_scale_bias_iterator.h"
+#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h"
+#include "cutlass/gemm/warp/scale_bias_tile_iterator.h"

 /////////////////////////////////////////////////////////////////////////////////////////////////

@@ -161,7 +161,7 @@ struct DefaultConv2dFpropFusion <
       LayoutScaleBias>;

   using SmemIteratorScaleBias =
-      cutlass::conv::threadblock::RegularScaleBiasVectorAccessIterator<
+      cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator<
          cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
          LayoutScaleBias>;

@@ -172,7 +172,7 @@ struct DefaultConv2dFpropFusion <
   static int const kThreadCount = 32;

   // Warp-level iterators to load scale and bias vectors
-  using WarpIteratorScaleBias = cutlass::conv::warp::WarpIteratorScaleBias<
+  using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator<
       MatrixShape<WarpShape::kM, WarpShape::kK>, ElementScaleBias,
       LayoutScaleBias, MatrixShape<InstructionShape::kM, InstructionShape::kK>,
       typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount,

@@ -296,7 +296,7 @@ struct DefaultConv2dFpropFusion <
       LayoutScaleBias>;

   using SmemIteratorScaleBias =
-      cutlass::conv::threadblock::RegularScaleBiasVectorAccessIterator<
+      cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator<
          cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
          LayoutScaleBias>;

@@ -307,7 +307,7 @@ struct DefaultConv2dFpropFusion <
   static int const kThreadCount = 32;

   // Warp-level iterators to load scale and bias vectors
-  using WarpIteratorScaleBias = cutlass::conv::warp::WarpIteratorScaleBias<
+  using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator<
       MatrixShape<WarpShape::kM, WarpShape::kK>, ElementScaleBias,
       LayoutScaleBias, MatrixShape<InstructionShape::kM, InstructionShape::kK>,
       typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount,

include/cutlass/conv/kernel/default_conv2d_group_fprop.h (new file, 222 lines)
@@ -0,0 +1,222 @@

/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief
    Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
    matrix multiply-add with the appropriate threadblock-scoped epilogue.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_conv2d.h"

#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h"
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h"

#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h"
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv2dGroupFprop
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename OperatorClass,
  typename ArchTag,
  typename ThreadblockShape,
  typename WarpShape,
  typename InstructionShape,
  typename EpilogueOutputOp,
  typename ThreadblockSwizzle,
  int Stages,
  typename MathOperatorTag,
  conv::GroupMode GroupMode,
  conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
  conv::StrideSupport StrideSupport = StrideSupport::kStrided,
  /// Access granularity of A matrix in units of elements
  int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
  /// Access granularity of B matrix in units of elements
  int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
> struct DefaultConv2dGroupFprop;

/////////////////////////////////////////////////////////////////////////////////////////////////
// OpClassTensorOp convolutions
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dGroupFprop specialization for Analytic IteratorAlgorithm and
/// multistage pipeline.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename ArchTag,
  typename ThreadblockShape,
  typename WarpShape,
  typename InstructionShape,
  typename EpilogueOutputOp,
  typename ThreadblockSwizzle,
  int Stages,
  typename MathOperatorTag,
  conv::GroupMode GroupMode,
  conv::StrideSupport StrideSupport,
  int AlignmentA,
  int AlignmentB
>
struct DefaultConv2dGroupFprop <
  ElementA,
  LayoutA,
  ElementB,
  LayoutB,
  ElementC,
  LayoutC,
  ElementAccumulator,
  arch::OpClassTensorOp,
  ArchTag,
  ThreadblockShape,
  WarpShape,
  InstructionShape,
  EpilogueOutputOp,
  ThreadblockSwizzle,
  Stages,
  MathOperatorTag,
  GroupMode,
  IteratorAlgorithm::kAnalytic,
  StrideSupport,
  AlignmentA,
  AlignmentB
> {

  // Define the core components from GEMM
  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
      ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
      Stages, MathOperatorTag>;

  // Define iterators over tiles from the A operand
  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
  using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
  using IteratorA =
    cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
      ElementA, LayoutA,
      ThreadMapA,
      AccessTypeA,
      GroupMode
    >;

  using SmemIteratorA = typename MmaCore::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
  using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
  using IteratorB =
    cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
      ElementB, LayoutB,
      ThreadMapB,
      AccessTypeB,
      GroupMode
    >;

  using SmemIteratorB = typename MmaCore::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
  using MmaPolicy = typename MmaCore::MmaPolicy;

  static cutlass::arch::CacheOperation::Kind const CacheOpB =
      ((sizeof_bits<ElementB>::value * AlignmentB) == 128)
          ? cutlass::arch::CacheOperation::Global
          : cutlass::arch::CacheOperation::Always;

  // Define the Mma
  using Mma = threadblock::ImplicitGemmMultistage<
    ThreadblockShape,
    IteratorA,
    SmemIteratorA,
    arch::CacheOperation::Always,
    IteratorB,
    SmemIteratorB,
    CacheOpB,
    MmaPolicy,
    Stages
  >;

  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
    ThreadblockShape,
    WarpMmaTensorOp,
    kPartitionsK,
    EpilogueOutputOp,
    EpilogueOutputOp::kCount
  >::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
    Mma,
    Epilogue,
    ThreadblockSwizzle,
    conv::Operator::kFprop,
    Conv2dProblemSize,
    GroupMode
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

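A sketch of how this default might be instantiated; the tile shapes, element types, and architecture tag below are illustrative assumptions, not values taken from the diff.

// Hypothetical instantiation sketch (assumes the CUTLASS headers above are included).
using GroupFpropKernel = cutlass::conv::kernel::DefaultConv2dGroupFprop<
    cutlass::half_t, cutlass::layout::TensorNHWC,   // ElementA, LayoutA (activations)
    cutlass::half_t, cutlass::layout::TensorNHWC,   // ElementB, LayoutB (filters)
    float, cutlass::layout::TensorNHWC,             // ElementC, LayoutC (output)
    float,                                          // ElementAccumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,         // ThreadblockShape
    cutlass::gemm::GemmShape<64, 64, 32>,           // WarpShape
    cutlass::gemm::GemmShape<16, 8, 16>,            // InstructionShape
    cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,                                              // Stages
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::GroupMode::kSingleGroup,
    cutlass::conv::IteratorAlgorithm::kAnalytic
>::Kernel;
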
include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h (new file, 360 lines)
@@ -0,0 +1,360 @@

/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (standard CUTLASS BSD-3-Clause header, identical to the one above)
 **************************************************************************************************/

/*! \file
    \brief
    Default kernel-level fused activation's scale+bias+relu and implicit GEMM convolution
    definitions that combine threadblock-scoped matrix multiply-add with the
    appropriate threadblock-scoped epilogue.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_conv2d.h"

#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h"
#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h"
#include "cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h"
#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h"
#include "cutlass/gemm/warp/scale_bias_tile_iterator.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for fused batch norm and Conv3dFprop
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementScaleBias,
  typename LayoutScaleBias,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename OperatorClass,
  typename ArchTag,
  typename ThreadblockShape,
  typename WarpShape,
  typename InstructionShape,
  typename EpilogueOutputOp,
  typename ThreadblockSwizzle,
  int Stages,
  typename MathOperatorTag,
  conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
  conv::StrideSupport StrideSupport = StrideSupport::kStrided
> struct DefaultConv3dFpropFusion;

/////////////////////////////////////////////////////////////////////////////////////////////////
// OpClassTensorOp convolutions
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv3dFprop specialization for Analytic IteratorAlgorithm and multistage
/// pipeline.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementScaleBias,
  typename LayoutScaleBias,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename ArchTag,
  typename ThreadblockShape,
  typename WarpShape,
  typename InstructionShape,
  typename EpilogueOutputOp,
  typename ThreadblockSwizzle,
  int Stages,
  typename MathOperatorTag
>
struct DefaultConv3dFpropFusion <
  ElementA,
  LayoutA,
  ElementB,
  LayoutB,
  ElementScaleBias,
  LayoutScaleBias,
  ElementC,
  LayoutC,
  ElementAccumulator,
  arch::OpClassTensorOp,
  ArchTag,
  ThreadblockShape,
  WarpShape,
  InstructionShape,
  EpilogueOutputOp,
  ThreadblockSwizzle,
  Stages,
  MathOperatorTag,
  IteratorAlgorithm::kAnalytic
> {

  // Define the core components from GEMM
  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
      ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
      Stages, MathOperatorTag>;

  // Define iterators over tiles from the A operand
  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
  using IteratorA =
    cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
      ElementA,
      ThreadMapA
    >;

  using SmemIteratorA = typename MmaCore::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
  using IteratorB =
    cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic<
      cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
      ElementB,
      ThreadMapB
    >;

  using SmemIteratorB = typename MmaCore::SmemIteratorB;

  /// Define iterators over tiles from scale/bias vectors
  using IteratorScaleBias =
    cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator<
      cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
      LayoutScaleBias>;

  using SmemIteratorScaleBias =
    cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator<
      cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
      LayoutScaleBias>;

  // Warp-level GEMM components
  using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
  using MmaPolicy = typename MmaCore::MmaPolicy;

  static int const kThreadCount = 32;

  // Warp-level iterators to load scale and bias vectors
  using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator<
      MatrixShape<WarpShape::kM, WarpShape::kK>, ElementScaleBias,
      LayoutScaleBias, MatrixShape<InstructionShape::kM, InstructionShape::kK>,
      typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount,
      MmaCore::WarpCount::kK>;

  // Define the Mma
  using Mma = threadblock::ImplicitGemmFpropFusionMultistage<
    ThreadblockShape,
    IteratorA,
    SmemIteratorA,
    arch::CacheOperation::Always,
    IteratorB,
    SmemIteratorB,
    arch::CacheOperation::Global,
    IteratorScaleBias,
    SmemIteratorScaleBias,
    arch::CacheOperation::Always,
    MmaPolicy,
    WarpIteratorScaleBias,
    Stages
  >;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
    ThreadblockShape,
    WarpMmaTensorOp,
    1,
    EpilogueOutputOp,
    EpilogueOutputOp::kCount
  >::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion<
    Mma,
    Epilogue,
    ThreadblockSwizzle,
    conv::Operator::kFprop,
    Conv3dProblemSize
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv3dFprop specialization for Optimized IteratorAlgorithm and
/// multistage pipeline.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementScaleBias,
  typename LayoutScaleBias,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename ArchTag,
  typename ThreadblockShape,
  typename WarpShape,
  typename InstructionShape,
  typename EpilogueOutputOp,
  typename ThreadblockSwizzle,
  int Stages,
  typename MathOperatorTag
>
struct DefaultConv3dFpropFusion <
  ElementA,
  LayoutA,
  ElementB,
  LayoutB,
  ElementScaleBias,
  LayoutScaleBias,
  ElementC,
  LayoutC,
  ElementAccumulator,
  arch::OpClassTensorOp,
  ArchTag,
  ThreadblockShape,
  WarpShape,
  InstructionShape,
  EpilogueOutputOp,
  ThreadblockSwizzle,
  Stages,
  MathOperatorTag,
  IteratorAlgorithm::kOptimized
> {

  // Define the core components from GEMM
  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
      ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
      ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
      Stages, MathOperatorTag
  >;

  // Define iterators over tiles from the A operand
  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
  using IteratorA =
    cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized<
      cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
      ElementA,
      LayoutA,
      ThreadMapA
    >;

  using SmemIteratorA = typename MmaCore::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
  using IteratorB =
    cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized<
      cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
      ElementB,
      LayoutB,
      ThreadMapB
    >;

  using SmemIteratorB = typename MmaCore::SmemIteratorB;

  /// Define iterators over tiles from scale/bias vectors
  using IteratorScaleBias =
    cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator<
      cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
      LayoutScaleBias>;

  using SmemIteratorScaleBias =
    cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator<
      cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
      LayoutScaleBias>;

  // Warp-level GEMM components
  using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
  using MmaPolicy = typename MmaCore::MmaPolicy;

  static int const kThreadCount = 32;

  // Warp-level iterators to load scale and bias vectors
  using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator<
      MatrixShape<WarpShape::kM, WarpShape::kK>, ElementScaleBias,
      LayoutScaleBias, MatrixShape<InstructionShape::kM, InstructionShape::kK>,
      typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount,
      MmaCore::WarpCount::kK>;

  // Define the Mma
  using Mma = threadblock::ImplicitGemmFpropFusionMultistage<
    ThreadblockShape,
    IteratorA,
    SmemIteratorA,
    arch::CacheOperation::Always,
    IteratorB,
    SmemIteratorB,
    arch::CacheOperation::Global,
    IteratorScaleBias,
    SmemIteratorScaleBias,
    arch::CacheOperation::Always,
    MmaPolicy,
    WarpIteratorScaleBias,
    Stages
  >;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
    ThreadblockShape,
    WarpMmaTensorOp,
    1,
    EpilogueOutputOp,
    EpilogueOutputOp::kCount
  >::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion<
    Mma,
    Epilogue,
    ThreadblockSwizzle,
    conv::Operator::kFprop,
    Conv3dProblemSize
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

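The scale/bias iterators above feed a per-channel scale and bias into the mainloop as activations are loaded. As a scalar reference for the fused operation named in the file brief (a sketch, not the library's implementation):

// Scalar reference for fused scale+bias+relu applied per activation element.
#include <algorithm>

float fused_scale_bias_relu(float x, float scale, float bias) {
  return std::max(scale * x + bias, 0.0f);
}
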
include/cutlass/conv/kernel/default_depthwise_fprop.h (new file, 218 lines)
@@ -0,0 +1,218 @@

/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (standard CUTLASS BSD-3-Clause header, identical to the one above)
 **************************************************************************************************/

/*! \file
    \brief
    Default kernel-level depthwise implicit GEMM convolution definitions combine threadblock-scoped
    matrix multiply-add with the appropriate threadblock-scoped epilogue.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_conv2d.h"

#include "cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h"

#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"

#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/depthwise_fprop_pipelined.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for depthwise Conv2dFprop
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename OperatorClass,
  typename ArchTag,
  typename ThreadblockShape,
  typename WarpShape,
  typename InstructionShape,
  typename EpilogueOutputOp,
  typename ThreadblockSwizzle,
  int Stages,
  typename MathOperatorTag,
  conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
  conv::StrideSupport StrideSupport = StrideSupport::kStrided,
  /// Access granularity of A matrix in units of elements
  int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
  /// Access granularity of B matrix in units of elements
  int AlignmentB = cutlass::sizeof_bits<ElementB>::value / cutlass::sizeof_bits<ElementB>::value  // == 1 element
> struct DefaultDepthwiseFprop;

/////////////////////////////////////////////////////////////////////////////////////////////////
// OpClassSimt convolutions
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for depthwise fprop specialization for Analytic IteratorAlgorithm,
/// 2 stage pipeline, and FFMA-based mainloop for SM50
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementAccumulator,
  typename ArchTag,
  typename ThreadblockShape,
  typename WarpShape,
  typename InstructionShape,
  typename EpilogueOutputOp,
  typename ThreadblockSwizzle,
  typename MathOperatorTag,
  conv::StrideSupport StrideSupport,
  int AlignmentA,
  int AlignmentB
>
struct DefaultDepthwiseFprop <
  ElementA,
  LayoutA,
  ElementB,
  LayoutB,
  ElementC,
  LayoutC,
  ElementAccumulator,
  arch::OpClassSimt,
  ArchTag,
  ThreadblockShape,
  WarpShape,
  InstructionShape,
  EpilogueOutputOp,
  ThreadblockSwizzle,
  2,
  MathOperatorTag, // cutlass::arch::OpMultiplyAdd
  IteratorAlgorithm::kAnalytic,
  StrideSupport,
  AlignmentA,
  AlignmentB
> {

  // Define the core components from GEMM
  using MmaCore = typename cutlass::conv::threadblock::DepthwiseMmaCoreWithLaneAccessSize<
      ThreadblockShape,
      WarpShape,
      InstructionShape,
      ElementA,
      layout::RowMajor,
      ElementB,
      layout::ColumnMajor,
      ElementAccumulator,
      layout::RowMajor,
      arch::OpClassSimt,
      128,
      sizeof_bits<ElementB>::value,
      2,
      MathOperatorTag>;

  // Define iterators over tiles from the A operand
  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
  using IteratorA =
    cutlass::conv::threadblock::TileIterator<
      cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
        ElementA, LayoutA,
        ThreadMapA
      >
    >;

  using SmemIteratorA = typename MmaCore::SmemIteratorA;

  // Define iterators over tiles from the B operand
  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
  using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
  using IteratorB =
    cutlass::conv::threadblock::TileIterator<
      cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
        ElementB, LayoutB,
        ThreadMapB,
        AccessTypeB,
        cutlass::conv::GroupMode::kDepthwise
      >
    >;

  using SmemIteratorB = typename MmaCore::SmemIteratorB;

  // Warp-level GEMM components
  using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
  using MmaPolicy = typename MmaCore::MmaPolicy;

  // Define the Mma
  using Mma = threadblock::DepthwiseFpropPipelined<
    ThreadblockShape,
    IteratorA,
    SmemIteratorA,
    IteratorB,
    SmemIteratorB,
    ElementC,
    LayoutC,
    MmaPolicy
  >;

  // Define the epilogue
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt<
    ThreadblockShape,
    WarpMmaSimtOp,
    EpilogueOutputOp,
    EpilogueOutputOp::kCount
  >::Epilogue;

  // Define the kernel
  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
    Mma,
    Epilogue,
    ThreadblockSwizzle,
    conv::Operator::kFprop,
    Conv2dProblemSize,
    cutlass::conv::GroupMode::kDepthwise
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

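A sketch of how the SIMT depthwise specialization might be instantiated; all tile shapes and element types below are illustrative assumptions modeled on the f16 NHWC SIMT SM60 configuration exercised by the unit tests, not values from the diff.

// Hypothetical instantiation sketch.
using DepthwiseFpropKernel = cutlass::conv::kernel::DefaultDepthwiseFprop<
    cutlass::half_t, cutlass::layout::TensorNHWC,   // activations
    cutlass::half_t, cutlass::layout::TensorNHWC,   // filters
    cutlass::half_t, cutlass::layout::TensorNHWC,   // output
    cutlass::half_t,                                // accumulator
    cutlass::arch::OpClassSimt,
    cutlass::arch::Sm60,
    cutlass::gemm::GemmShape<64, 64, 8>,            // ThreadblockShape
    cutlass::gemm::GemmShape<32, 32, 8>,            // WarpShape
    cutlass::gemm::GemmShape<1, 1, 1>,              // InstructionShape (SIMT)
    cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 1, cutlass::half_t, cutlass::half_t>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2,                                              // Stages (two-stage pipeline)
    cutlass::arch::OpMultiplyAdd
>::Kernel;
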
@@ -62,7 +62,8 @@ template <
   typename Epilogue_,                            ///! Epilogue
   typename ThreadblockSwizzle_,                  ///! Threadblock swizzling function
   conv::Operator ConvOperator,                   ///! Convolutional operator (Fprop, Dgrad, Wgrad)
-  typename ConvProblemSize_ = Conv2dProblemSize  ///! Convolutional operator on 2D or 3D problem
+  typename ConvProblemSize_ = Conv2dProblemSize, ///! Convolutional operator on 2D or 3D problem
+  conv::GroupMode GroupMode_ = conv::GroupMode::kNone ///! Group mode
 >
 struct ImplicitGemmConvolution {

@@ -117,6 +118,8 @@ struct ImplicitGemmConvolution {
   /// Conv dimension and problem size structure (Conv2d or Conv3d)
   using ConvProblemSize = ConvProblemSize_;

+  static conv::GroupMode const kGroupMode = GroupMode_;
+
   /// Wgrad C stride idx for implicit gemm algorithm
   // Conv2d row-major matrix C (KxRSC)
   // Conv3d row-major matrix C (KxTRSC)

@@ -198,6 +201,7 @@ struct ImplicitGemmConvolution {
     int swizzle_log_tile;

     int gemm_k_iterations;
+    int gemm_k_iterations_per_channel;
     typename Mma::IteratorA::Params iterator_A;
     typename Mma::IteratorA::Element const *ptr_A;
     typename Mma::IteratorB::Params iterator_B;

@@ -241,7 +245,12 @@ struct ImplicitGemmConvolution {
         kConvolutionalOperator,
         ThreadblockShape::kK,
         args.problem_size,
-        kIteratorAlgorithm);
+        kIteratorAlgorithm,
+        kGroupMode,
+        ThreadblockShape::kN);
+
+      gemm_k_iterations_per_channel = implicit_gemm_k_iterations_per_channel(
+        kConvolutionalOperator, ThreadblockShape::kK, args.problem_size, kIteratorAlgorithm);

       ThreadblockSwizzle threadblock_swizzle;

@@ -286,6 +295,17 @@ struct ImplicitGemmConvolution {

     // Compute position within threadblock
     int thread_idx = threadIdx.x;
+    int iterator_A_column_offset = threadblock_tile_idx.k() * Mma::Shape::kK;
+    if (kGroupMode != GroupMode::kNone) {
+      if (kGroupMode != GroupMode::kDepthwise) {
+        int k_per_group = params.problem_size.K / params.problem_size.groups;
+        int group_idx = threadblock_tile_idx.n() * Mma::Shape::kN / k_per_group;
+        int channels_per_group = params.problem_size.C / params.problem_size.groups;
+        iterator_A_column_offset += group_idx * channels_per_group;
+      } else {
+        iterator_A_column_offset += threadblock_tile_idx.n() * Mma::Shape::kN;
+      }
+    }

     // Construct iterators to A and B operands
     typename Mma::IteratorA iterator_A(

@@ -295,7 +315,7 @@ struct ImplicitGemmConvolution {
       thread_idx,
       MatrixCoord(
         threadblock_tile_idx.m() * Mma::Shape::kM,
-        threadblock_tile_idx.k() * Mma::Shape::kK
+        iterator_A_column_offset
       )
     );

@@ -327,7 +347,7 @@ struct ImplicitGemmConvolution {
     accumulators.clear();

     // Compute threadblock-scoped matrix multiply-add
-    mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
+    mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, params.gemm_k_iterations_per_channel);

     //
     // Epilogue

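A worked instance of the activation iterator's column offset above, with assumed numbers (not from the diff), for a kSingleGroup tile:

// Worked example: K = 256, C = 256, groups = 4, Mma::Shape::kN = 64,
// threadblock_tile_idx.n() = 2, threadblock_tile_idx.k() = 0, Mma::Shape::kK = 32.
#include <cassert>

int main() {
  int K = 256, C = 256, groups = 4, tile_n = 64, tile_idx_n = 2, tile_idx_k = 0, tile_k = 32;
  int k_per_group = K / groups;                       // 64: output channels per group
  int group_idx = tile_idx_n * tile_n / k_per_group;  // tile 2 works on group 2
  int channels_per_group = C / groups;                // 64 input channels per group
  int offset = tile_idx_k * tile_k + group_idx * channels_per_group;
  assert(offset == 128);  // this tile reads activations starting at channel 128
  return 0;
}
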
@@ -119,6 +119,8 @@ struct ImplicitGemmConvolutionFusion {
   /// Conv dimension and problem size structure (Conv2d or Conv3d)
   using ConvProblemSize = ConvProblemSize_;

+  static conv::GroupMode const kGroupMode = conv::GroupMode::kNone;
+
   /// Wgrad C stride idx for implicit gemm algorithm
   // Conv2d row-major matrix C (KxRSC)
   // Conv3d row-major matrix C (KxTRSC)

@@ -117,6 +117,8 @@ struct ImplicitGemmConvolutionStridedDgrad {
   /// Conv dimension and problem size structure (Conv2d or Conv3d)
   using ConvProblemSize = ConvProblemSize_;

+  static conv::GroupMode const kGroupMode = conv::GroupMode::kNone;
+
   /// Wgrad C stride idx for implicit gemm algorithm
   // Conv2d row-major matrix C (KxRSC)
   // Conv3d row-major matrix C (KxTRSC)

@@ -488,4 +490,3 @@ struct ImplicitGemmConvolutionStridedDgrad {
 } // namespace cutlass

 /////////////////////////////////////////////////////////////////////////////////////////////////
-

@@ -117,6 +117,8 @@ struct ImplicitGemmConvolutionWithFusedEpilogue {
   /// Conv dimension and problem size structure (Conv2d or Conv3d)
   using ConvProblemSize = ConvProblemSize_;

+  static conv::GroupMode const kGroupMode = conv::GroupMode::kNone;
+
   /// Wgrad C stride idx for implicit gemm algorithm
   // Conv2d row-major matrix C (KxRSC)
   // Conv3d row-major matrix C (KxTRSC)

@@ -248,7 +248,7 @@ public:
     pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
   }

-  CUTLASS_HOST_DEVICE
+  CUTLASS_DEVICE
   void advance() {

     int next_idx = 0;

@@ -263,18 +263,33 @@ public:

       // Move filter_r by stride_h
       filter_r_ += problem_size_.stride_h;

+#if 0
       bool check = (filter_r_ < problem_size_.R);

       filter_r_ = check ? filter_r_ : start_r_;
       next_idx = check ? 1 : 2;
       reset_bytes += (check ? reset_bytes_s_ : reset_bytes_r_);
+#else
+      asm volatile(
+          "{\n\t"
+          " .reg .pred %%p;\n\t"
+          " .reg .s64 t1;\n\t"
+          " setp.lt.s32 %%p, %3, %4;\n\t"
+          " selp.s32 %0, %3, %5, %%p;\n\t"
+          " selp.s32 %1, 1, 2, %%p;\n\t"
+          " selp.s64 t1, %6, %7, %%p;\n\t"
+          " add.s64 %2, %8, t1;\n\t"
+          "}\n"
+          : "=r"(filter_r_), "=r"(next_idx), "=l"(reset_bytes)
+          : "r"(filter_r_), "r"(problem_size_.R), "r"(start_r_),
+            "l"(reset_bytes_s_), "l"(reset_bytes_r_), "l"(reset_bytes));
+#endif
     }

     // offset pointers by offset_bytes
     pointer_ += (params_.inc_next[next_idx] - reset_bytes);

     if (next_idx == 2) {
       filter_k_ += params_.filter_k_delta;
     }

@@ -528,7 +528,6 @@ public:
     int k = filter_k_ + iteration_vector_ * AccessType::kElements;

     return TensorCoord(n, p, q, k);
-
   }

   /// Returns true if the current coordinate is within the output tensor Dy

@@ -321,7 +321,7 @@ public:
     add_byte_offset_(pointer_offset * sizeof_bits<Element>::value / 8);
   }

-  CUTLASS_HOST_DEVICE
+  CUTLASS_DEVICE
   void advance() {

     int next_idx = 0;

@@ -336,8 +336,9 @@ public:

       // Move filter_r by stride_h
       filter_r_ += problem_size_.stride_h;
+#if 0
       if (filter_r_ < problem_size_.R) {

         next_idx = 1;

         // Restore bytes in q coordinate (Mma in filter s dimension)

@@ -347,12 +348,25 @@ public:

         // Restore filter_r
         filter_r_ = start_r_;

         next_idx = 2;

         // Restore bytes in p and q coordinates (Mma in filter s and r dimensions)
         reset_bytes = reset_bytes_r_;
       }
+#else
+      asm volatile(
+          "{\n\t"
+          " .reg .pred %%p;\n\t"
+          " setp.lt.s32 %%p, %3, %4;\n\t"
+          " selp.s32 %0, %3, %5, %%p;\n\t"
+          " selp.s32 %1, 1, 2, %%p;\n\t"
+          " selp.s64 %2, %6, %7, %%p;\n\t"
+          "}\n"
+          : "=r"(filter_r_), "=r"(next_idx), "=l"(reset_bytes)
+          : "r"(filter_r_), "r"(problem_size_.R), "r"(start_r_),
+            "l"(reset_bytes_s_), "l"(reset_bytes_r_));
+#endif
     }

     // offset pointers by offset_bytes

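For readers unfamiliar with the inline PTX in the two advance() bodies: setp.lt.s32 sets a predicate from filter_r_ < R, and each selp is a predicated (branch-free) select; the retired C++ under #if 0 is the exact reference. A condensed scalar restatement (a sketch mirroring that reference path):

// Branch-free selection of the next increment index and rewind bytes.
#include <cstdint>

void advance_select(int &filter_r, int &next_idx, int64_t &reset_bytes,
                    int R, int start_r, int64_t reset_bytes_s, int64_t reset_bytes_r) {
  bool p = filter_r < R;                            // setp.lt.s32
  filter_r    = p ? filter_r : start_r;             // selp.s32
  next_idx    = p ? 1 : 2;                          // selp.s32
  reset_bytes = p ? reset_bytes_s : reset_bytes_r;  // selp.s64
}
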
@@ -67,7 +67,8 @@ template <
   typename Element_,
   typename Layout_,
   typename ThreadMap_,
-  typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
+  typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>,
+  conv::GroupMode GroupMode_ = conv::GroupMode::kNone
 >
 class Conv2dFpropActivationTileAccessIteratorAnalytic {
 public:

@@ -89,6 +90,7 @@ public:
   static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
   static int const kConvDim = 2;
   using ConvProblemSize = typename conv::Conv2dProblemSize;
+  static conv::GroupMode const kGroupMode = GroupMode_;

   static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;

@@ -119,6 +121,11 @@ private:
   int filter_c_;
   int filter_r_;
   int filter_s_;
+  int filter_c_init_;
+  int group_idx_offset_;
+  int channels_per_group_;
+  int crs_cnt_;
+  int crs_per_group_;

   int offset_n_[ThreadMap::Iterations::kStrided];
   int offset_p_[ThreadMap::Iterations::kStrided];

@@ -137,6 +144,8 @@ public:
     params_(params),
     problem_size_(problem_size),
     pointer_(reinterpret_cast<char const *>(ptr)),
+    crs_cnt_(0),
+    group_idx_offset_(0),
     filter_c_(0),
     filter_r_(0),
     filter_s_(0) {

@@ -145,6 +154,12 @@ public:

     filter_c_ = threadblock_offset.column() + thread_coord.contiguous();

+    if (kGroupMode != conv::GroupMode::kNone) {
+      filter_c_init_ = filter_c_;
+      channels_per_group_ = problem_size_.C / problem_size_.groups;
+      crs_per_group_ = problem_size_.S * problem_size_.R * ((channels_per_group_ + Shape::kColumn - 1) / Shape::kColumn);
+    }
+
     CUTLASS_PRAGMA_UNROLL
     for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
       int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;

@@ -182,6 +197,10 @@ public:
   CUTLASS_HOST_DEVICE
   void advance() {
     // moves to the next tile
+    if (kGroupMode != conv::GroupMode::kNone) {
+      ++crs_cnt_;
+    }
+
     ++filter_s_;
     if (filter_s_ < problem_size_.S) {
       return;

@@ -192,8 +211,19 @@ public:
       return;
     }
     filter_r_ = 0;

-    filter_c_ += Shape::kColumn * problem_size_.split_k_slices;
+    if (kGroupMode == conv::GroupMode::kNone) {
+      filter_c_ += Shape::kColumn * problem_size_.split_k_slices;
+    } else {
+      if (crs_cnt_ == crs_per_group_) {
+        // moves to next group
+        crs_cnt_ = 0;
+        ++group_idx_offset_;
+        filter_c_ = group_idx_offset_ * channels_per_group_ + filter_c_init_;
+      } else {
+        filter_c_ += Shape::kColumn * problem_size_.split_k_slices;
+      }
+    }
   }

   /// Returns the coordinate in the activations tensor X that is currently pointed to

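The crs_per_group_ count above fixes how many advance() calls the iterator spends inside one group before hopping to the next. A worked instance with assumed numbers (not from the diff):

// Worked example: R = S = 3, C = 256, groups = 4, Shape::kColumn = 32.
#include <cassert>

int main() {
  int R = 3, S = 3, C = 256, groups = 4, shape_k_column = 32;  // tile depth in channels
  int channels_per_group = C / groups;                         // 64
  int crs_per_group = S * R * ((channels_per_group + shape_k_column - 1) / shape_k_column);
  assert(crs_per_group == 3 * 3 * 2);  // 18 advances: filter positions x channel slices
  return 0;
}
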
@@ -66,7 +66,8 @@ template <
   typename Element_,
   typename Layout_,
   typename ThreadMap_,
-  typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
+  typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>,
+  conv::GroupMode GroupMode_ = conv::GroupMode::kNone
 >
 class Conv2dFpropFilterTileAccessIteratorAnalytic {
 public:

@@ -88,6 +89,7 @@ public:
   static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
   static int const kConvDim = 2;
   using ConvProblemSize = typename conv::Conv2dProblemSize;
+  static conv::GroupMode const kGroupMode = GroupMode_;

   static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;

@@ -118,8 +120,14 @@ private:
   int filter_r_;
   int filter_s_;
   int filter_c_;
+  int filter_c_init_;
+  int crs_cnt_;
+  int crs_per_group_;
+  int group_idx_offset_c_;
+  int channels_per_group_;

   int offset_k_[ThreadMap::Iterations::kStrided];
+  int group_idx_offset_k_[ThreadMap::Iterations::kStrided];

 public:

@@ -134,6 +142,8 @@ public:
     params_(params),
     problem_size_(problem_size),
     pointer_(reinterpret_cast<char const *>(ptr)),
+    crs_cnt_(0),
+    group_idx_offset_c_(0),
     filter_r_(0),
     filter_s_(0),
     filter_c_(0) {

@@ -142,9 +152,23 @@ public:

     filter_c_ = threadblock_offset.row() + thread_coord.contiguous();

+    if (kGroupMode != conv::GroupMode::kNone) {
+      filter_c_init_ = filter_c_;
+      if (kGroupMode == conv::GroupMode::kDepthwise) {
+        channels_per_group_ = 1;
+        crs_per_group_ = problem_size_.S * problem_size_.R;
+      } else {
+        channels_per_group_ = problem_size_.C / problem_size_.groups;
+        crs_per_group_ = problem_size_.S * problem_size_.R * ((channels_per_group_ + Shape::kRow - 1) / Shape::kRow);
+      }
+    }
+
     CUTLASS_PRAGMA_UNROLL
     for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
       offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;
+      if (kGroupMode != conv::GroupMode::kNone && kGroupMode != conv::GroupMode::kDepthwise) {
+        group_idx_offset_k_[s] = (thread_coord.strided() + s * ThreadMap::Delta::kStrided) / (problem_size_.K / problem_size_.groups);
+      }
     }

     set_iteration_index(0);

@@ -168,6 +192,10 @@ public:
   CUTLASS_HOST_DEVICE
   void advance() {
     // moves to the next tile
+    if (kGroupMode != conv::GroupMode::kNone) {
+      ++crs_cnt_;
+    }
+
     ++filter_s_;
     if (filter_s_ < problem_size_.S) {
       return;

@@ -179,8 +207,21 @@ public:
       return;
     }
     filter_r_ = 0;

-    filter_c_ += Shape::kRow * problem_size_.split_k_slices;
+    if (kGroupMode == conv::GroupMode::kNone) {
+      filter_c_ += Shape::kRow * problem_size_.split_k_slices;
+    } else {
+      if (crs_cnt_ == crs_per_group_) {
+        crs_cnt_ = 0;
+        filter_c_ = filter_c_init_;
+        if (kGroupMode != conv::GroupMode::kDepthwise) {
+          // moves to next group
+          ++group_idx_offset_c_;
+        }
+      } else {
+        filter_c_ += Shape::kRow * problem_size_.split_k_slices;
+      }
+    }
   }

   /// Returns the coordinate in the filter tensor W that is currently pointed to

@@ -200,8 +241,14 @@ public:

     TensorCoord coord = at();

-    return coord.n() < problem_size_.K &&
-      coord.c() < problem_size_.C;
+    if (kGroupMode == conv::GroupMode::kNone) {
+      return coord.n() < problem_size_.K && coord.c() < problem_size_.C;
+    } else if (kGroupMode == conv::GroupMode::kDepthwise) {
+      return coord.n() < problem_size_.K && coord.c() < 1;  // channels_per_group_ is always equal to one
+    } else {
+      return coord.n() < problem_size_.K && coord.c() < channels_per_group_ &&
+             group_idx_offset_c_ == group_idx_offset_k_[iteration_strided_];
+    }
   }

   /// Returns a pointer to the vector starting at the current coordinate

@ -554,20 +554,20 @@ struct Conv2dDgradOutputGradientIteratorOptimizedParams {

    // next S
    inc_next[0] = conv_sign * (
      layout.stride()[0] * problem_size.dilation_w
      (int64_t)layout.stride()[0] * problem_size.dilation_w
    ) * element_size_bits / 8;

    // next R
    inc_next[1] = conv_sign * (
      layout.stride()[1] * problem_size.dilation_h
      - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w
      (int64_t)layout.stride()[1] * problem_size.dilation_h
      - (problem_size.S - 1) * (int64_t)layout.stride()[0] * problem_size.dilation_w
    ) * element_size_bits / 8;

    // next K
    inc_next[2] = (
      threadblock_shape.column() * problem_size.split_k_slices
      - conv_sign * (problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h
      - conv_sign * (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w
      - conv_sign * (problem_size.R - 1) * (int64_t)layout.stride()[1] * problem_size.dilation_h
      - conv_sign * (problem_size.S - 1) * (int64_t)layout.stride()[0] * problem_size.dilation_w
    ) * element_size_bits / 8;

    // logical offset added to internal channel counter - units are elements, not bytes
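Editor's note: the parameter-object hunks in this file all apply the same fix — promoting a 32-bit stride to int64_t before it is multiplied, so the byte-offset product cannot overflow for very large tensors. A standalone sketch (the sizes are assumptions, not taken from the diff) shows what goes wrong without the cast:

#include <cstdint>
#include <cstdio>

int main() {
  int32_t stride_w = 400000000;  // elements between rows of a huge activation tensor
  int dilation = 8;

  int32_t narrow = stride_w * dilation;         // 32-bit product: signed overflow (UB)
  int64_t wide = (int64_t)stride_w * dilation;  // widened before multiplying: correct

  std::printf("narrow=%d wide=%lld\n", narrow, (long long)wide);
  return 0;
}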
@ -614,12 +614,12 @@ struct Conv2dStridedDgradOutputGradientIteratorOptimizedParams {

    // next S
    inc_next[0] = conv_sign * (
      layout.stride()[0] * problem_size.dilation_w
      (int64_t)layout.stride()[0] * problem_size.dilation_w
    ) * element_size_bits / 8;

    // next R
    inc_next[1] = conv_sign * (
      layout.stride()[1] * problem_size.dilation_h
      (int64_t)layout.stride()[1] * problem_size.dilation_h
    ) * element_size_bits / 8;

    // next K
@ -670,18 +670,18 @@ struct Conv2dDgradFilterIteratorOptimizedParams {

    TRACE_CONV_INITIALIZERS("conv2d_dgrad", "filter",
      element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta);

    inc_next_strided = (layout.stride()[2] * threadmap_delta.strided() * element_size_bits) / 8;
    inc_next_strided = ((int64_t)layout.stride()[2] * threadmap_delta.strided() * element_size_bits) / 8;

    inc_next_rs =
      ( layout.stride()[0]
        - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
      ( (int64_t)layout.stride()[0]
        - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2]
      ) * element_size_bits / 8;

    inc_next_k =
      (
        threadblock_shape.row() * problem_size.split_k_slices * layout.stride()[2]
        - (problem_size.R * problem_size.S - 1) * layout.stride()[0]
        - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
        threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[2]
        - (problem_size.R * problem_size.S - 1) * (int64_t)layout.stride()[0]
        - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2]
      ) * element_size_bits / 8;

    filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices;
@ -730,26 +730,26 @@ struct Conv2dStridedDgradFilterIteratorOptimizedParams {

    // next S
    inc_next[0] =
      ( layout.stride()[0] * problem_size.stride_w
      ( (int64_t)layout.stride()[0] * problem_size.stride_w
        //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
      ) * element_size_bits / 8;

    // next R
    inc_next[1] =
      ( layout.stride()[1] * problem_size.stride_h
      ( (int64_t)layout.stride()[1] * problem_size.stride_h
        //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
      ) * element_size_bits / 8;

    // next K
    inc_next[2] =
      (
        threadblock_shape.row() * problem_size.split_k_slices * layout.stride()[2]
        threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[2]
        //- (problem_size.R * problem_size.S - 1) * layout.stride()[0]
        //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
      ) * element_size_bits / 8;

    // offset in units of bytes to move the pointer in backward direction
    reset_bytes = (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
    reset_bytes = (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2]
      * element_size_bits / 8;

    filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices;
@ -800,13 +800,13 @@ struct Conv2dWgradOutputGradientIteratorOptimizedParams {

      element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta);

    // Incremental offsets in units of bytes (number of elements) * sizeof_bits<Element>::value / 8
    offset_next_strided = (threadmap_delta.strided() * layout.stride()[0])
    offset_next_strided = (threadmap_delta.strided() * (int64_t)layout.stride()[0])
      * element_size_bits / 8;

    offset_next_contiguous = (threadmap_delta.contiguous())
      * element_size_bits / 8;

    inc_next_npq = (threadblock_shape.column() * problem_size.split_k_slices * layout.stride()[0])
    inc_next_npq = (threadblock_shape.column() * problem_size.split_k_slices * (int64_t)layout.stride()[0])
      * element_size_bits / 8;
  }
};
@ -891,4 +891,3 @@ struct PredicatedScaleBiasVectorAccessIteratorParams {

} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -104,6 +104,11 @@ public:

    return TileAccessIterator::getParams(problem_size, layout);
  }

  /// Overrides the internal iteration index
  CUTLASS_HOST_DEVICE
  void set_iteration_index(Index index) {
    tile_access_iterator_.set_iteration_index(index);
  }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
@ -304,8 +304,8 @@ struct Conv3dDgradOutputGradientIteratorOptimizedParams {

    // logical offset added to internal channel counter - units are elements, not bytes
    filter_k_delta = threadblock_shape.column() * problem_size.split_k_slices;
  }

};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Parameters object for Conv3d DGRAD Filter (w) iterator
@ -343,18 +343,18 @@ struct Conv3dDgradFilterIteratorOptimizedParams {

    TRACE_CONV_INITIALIZERS("conv3d_dgrad", "filter",
      element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta);

    inc_next_strided = (layout.stride()[3] * threadmap_delta.strided() * element_size_bits) / 8;
    inc_next_strided = ((int64_t)layout.stride()[3] * threadmap_delta.strided() * element_size_bits) / 8;

    inc_next_trs =
      ( layout.stride()[0]
        - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[3]
      ( (int64_t)layout.stride()[0]
        - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[3]
      ) * element_size_bits / 8;

    inc_next_k =
      (
        threadblock_shape.row() * problem_size.split_k_slices * layout.stride()[3]
        - (problem_size.T * problem_size.R * problem_size.S - 1) * layout.stride()[0]
        - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[3]
        threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[3]
        - (problem_size.T * problem_size.R * problem_size.S - 1) * (int64_t)layout.stride()[0]
        - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[3]
      ) * element_size_bits / 8;

    filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices;
@ -408,13 +408,13 @@ struct Conv3dWgradOutputGradientIteratorOptimizedParams {

      element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta);

    // Incremental offsets in units of bytes (number of elements) * element_size_bits / 8
    offset_next_strided = (threadmap_delta.strided() * layout.stride()[0])
    offset_next_strided = (threadmap_delta.strided() * (int64_t)layout.stride()[0])
      * element_size_bits / 8;

    offset_next_contiguous = (threadmap_delta.contiguous())
      * element_size_bits / 8;

    inc_next_nzpq = (threadblock_shape.column() * problem_size.split_k_slices * layout.stride()[0])
    inc_next_nzpq = (threadblock_shape.column() * problem_size.split_k_slices * (int64_t)layout.stride()[0])
      * element_size_bits / 8;

    // Precompute several quantities for fast modulo arithmetic.
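Editor's note: the "fast modulo arithmetic" mentioned above means the params object does divisor-related work once on the host so device code avoids repeated division. The library's actual scheme precomputes more than this sketch shows (magic-number multipliers); the sketch below only illustrates the basic reuse idea — one hardware division, with the remainder recovered by a multiply-subtract (names are assumptions):

#include <cassert>

inline void fast_divmod(int n, int d, int &q, int &r) {
  q = n / d;      // single division...
  r = n - q * d;  // ...remainder without a second division
}

int main() {
  int q, r;
  fast_divmod(123, 7, q, r);
  assert(q == 17 && r == 4);
  return 0;
}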
336  include/cutlass/conv/threadblock/depthwise_fprop_pipelined.h  Normal file

@ -0,0 +1,336 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/numeric_conversion.h"

#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/mma_base.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Transformation applied to A operand
    typename TransformA_ = NumericArrayConverter<
        typename SmemIteratorA_::Element,
        typename IteratorA_::Element,
        IteratorA_::Fragment::kElements>,
    ///
    /// Transformation applied to B operand
    typename TransformB_ = NumericArrayConverter<
        typename SmemIteratorB_::Element,
        typename IteratorB_::Element,
        IteratorB_::Fragment::kElements>,
    /// Used for partial specialization
    typename Enable = bool
>
class DepthwiseFpropPipelined : public gemm::threadblock::MmaBase<Shape_, Policy_, 2> {
public:

  ///< Base class
  using Base = gemm::threadblock::MmaBase<Shape_, Policy_, 2>;

  using Shape = Shape_;          ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using IteratorA = IteratorA_;  ///< Iterates over tiles of A operand in global memory
  using IteratorB = IteratorB_;  ///< Iterates over tiles of B operand in global memory
  using ElementC = ElementC_;    ///< Data type of accumulator matrix
  using LayoutC = LayoutC_;      ///< Layout of accumulator matrix
  using Policy = Policy_;        ///< Policy describing tuning details

  using SmemIteratorA = SmemIteratorA_;
  using SmemIteratorB = SmemIteratorB_;

  using TransformA = TransformA_;
  using TransformB = TransformB_;

  //
  // Dependent types
  //

  /// Fragment of operand A loaded from global memory
  using FragmentA = typename IteratorA::Fragment;

  /// Fragment of operand B loaded from global memory
  using FragmentB = typename IteratorB::Fragment;

  /// Fragment of accumulator tile
  using FragmentC = typename Policy::Operator::FragmentC;

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  /// Obtain the arch tag from the warp-level operator
  using ArchTag = typename Policy::Operator::ArchTag;

  /// Complex transform on A operand
  static ComplexTransform const kTransformA = Operator::kTransformA;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB = Operator::kTransformB;

  // statically assert kStages for MmaPipelined is two (double-buffered pipeline)
  static_assert((Base::kStages == 2), "MmaPipelined requires kStages set to value 2");

private:

  using WarpFragmentA = typename Operator::FragmentA;
  using WarpFragmentB = typename Operator::FragmentB;

protected:

  /// Iterator to write threadblock-scoped tile of A operand to shared memory
  SmemIteratorA smem_iterator_A_;

  /// Iterator to write threadblock-scoped tile of B operand to shared memory
  SmemIteratorB smem_iterator_B_;

public:

  /// Construct from tensor references
  CUTLASS_DEVICE
  DepthwiseFpropPipelined(
    typename Base::SharedStorage &shared_storage,  ///< Shared storage needed for internal use by threadblock-scoped GEMM
    int thread_idx,                                ///< ID within the threadblock
    int warp_idx,                                  ///< ID of warp
    int lane_idx                                   ///< ID of each thread within a warp
  ):
    Base(shared_storage, thread_idx, warp_idx, lane_idx),
    smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
    smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {

    // Compute warp location within threadblock tile by mapping the warp_id to
    // three coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension

    int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
    int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

    int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
    int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

    // Add per-warp offsets in units of warp-level tiles
    this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
    this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
  }

  /// Perform a threadblock-scoped matrix multiply-accumulate
  CUTLASS_DEVICE
  void operator()(
    int gemm_k_iterations,                    ///< number of iterations of the mainloop
    FragmentC &accum,                         ///< destination accumulator tile
    IteratorA iterator_A,                     ///< iterator over A operand in global memory
    IteratorB iterator_B,                     ///< iterator over B operand in global memory
    FragmentC const &src_accum,               ///< source accumulator tile
    int gemm_k_iterations_per_channel = 0,    ///< number of iterations per channel
    TransformA transform_A = TransformA(),    ///< transformation applied to A fragment
    TransformB transform_B = TransformB()) {  ///< transformation applied to B fragment

    //
    // Prologue
    //

    // Perform accumulation in the 'd' output operand
    accum = src_accum;

    FragmentA tb_frag_A;
    FragmentB tb_frag_B;

    tb_frag_A.clear();
    tb_frag_B.clear();

    // The last kblock is loaded in the prolog
    iterator_A.load(tb_frag_A);
    iterator_B.load(tb_frag_B);

    ++iterator_A;
    ++iterator_B;

    this->smem_iterator_A_.store(transform_A(tb_frag_A));
    this->smem_iterator_B_.store(transform_B(tb_frag_B));

    ++this->smem_iterator_A_;
    ++this->smem_iterator_B_;

    __syncthreads();

    // Pair of fragments used to overlap shared memory loads and math instructions
    WarpFragmentA warp_frag_A[2];
    WarpFragmentB warp_frag_B[2];

    this->warp_tile_iterator_A_.set_kgroup_index(0);
    this->warp_tile_iterator_B_.set_kgroup_index(0);

    this->warp_tile_iterator_A_.load(warp_frag_A[0]);
    this->warp_tile_iterator_B_.load(warp_frag_B[0]);

    ++this->warp_tile_iterator_A_;
    ++this->warp_tile_iterator_B_;

    Operator warp_mma;

    int smem_write_stage_idx = 1;
    // Depthwise specific
    int channel_start_index = 0;
    int rs_plane_idx = 0;

    // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
    // shared memory loads (which have the tightest latency requirement).

    //
    // Mainloop
    //

    // Note: The main loop does not support Base::kWarpGemmIterations == 2.
    CUTLASS_GEMM_LOOP
    for (; gemm_k_iterations > 0; --gemm_k_iterations) {
      //
      // Loop over GEMM K dimension
      //

      if (rs_plane_idx == gemm_k_iterations_per_channel - 1) {
        // Reset iteration index.
        iterator_B.set_iteration_index(0);
      }

      CUTLASS_PRAGMA_UNROLL
      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {

        // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
        // as the case may be.

        if (warp_mma_k == Base::kWarpGemmIterations - 1) {

          // Write fragments to shared memory
          this->smem_iterator_A_.store(transform_A(tb_frag_A));

          this->smem_iterator_B_.store(transform_B(tb_frag_B));

          __syncthreads();

          if (rs_plane_idx == gemm_k_iterations_per_channel - 1) {
            // Move to next set of filter groups.
            channel_start_index += Base::kWarpGemmIterations;
          }

          ++this->smem_iterator_A_;
          ++this->smem_iterator_B_;

          // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
          if (smem_write_stage_idx == 1) {
            this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
            this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
          }
          else {
            this->warp_tile_iterator_A_.add_tile_offset(
                {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
            this->warp_tile_iterator_B_.add_tile_offset(
                {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations,
                 0});
          }

          smem_write_stage_idx ^= 1;
        }

        this->warp_tile_iterator_A_.set_kgroup_index(channel_start_index + (warp_mma_k + 1) % Base::kWarpGemmIterations);
        this->warp_tile_iterator_B_.set_kgroup_index(channel_start_index + (warp_mma_k + 1) % Base::kWarpGemmIterations);

        this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
        this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;

        if (warp_mma_k == 0) {

          iterator_A.load(tb_frag_A);
          iterator_B.load(tb_frag_B);

          ++iterator_A;
          ++iterator_B;
        }

        warp_mma(accum, warp_frag_A[warp_mma_k % 2],
                 warp_frag_B[warp_mma_k % 2], accum);
      }

      rs_plane_idx = (rs_plane_idx == gemm_k_iterations_per_channel - 1) ? 0 : (rs_plane_idx + 1);

    }

  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
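Editor's note: the mainloop above keeps two shared-memory stages and flips between them with `smem_write_stage_idx ^= 1`, so stores for iteration k+1 never race the math reading iteration k. A compact host-side model of that ping-pong (buffer contents stand in for smem tiles; all names here are assumptions, not CUTLASS types):

#include <array>
#include <cassert>

int main() {
  std::array<int, 2> smem_stage{0, -1};  // stage 0 was filled by the prologue
  int write_stage = 1;
  int read_stage = 0;

  for (int k = 1; k <= 8; ++k) {
    smem_stage[write_stage] = k;   // global->smem store for the next iteration
    assert(read_stage != write_stage);  // compute always reads the other stage
    write_stage ^= 1;              // toggle the circular buffer
    read_stage ^= 1;
  }
  return 0;
}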
@ -0,0 +1,337 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data
      layout of the global memory fragments, data types, and internal tile sizes.

      Partial specializations for threadblock::Mma operations targeting depthwise related simt instructions.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"

#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/gemm/warp/mma.h"
#include "cutlass/gemm/threadblock/mma_pipelined.h"
#include "cutlass/gemm/threadblock/mma_singlestage.h"

#include "cutlass/gemm/threadblock/mma_base.h"
#include "cutlass/conv/warp/mma_depthwise_simt.h"

#include "cutlass/arch/cache_operation.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace threadblock {

template <
    /// Shape of threadblock-scoped matrix multiply operator
    typename Shape,
    /// Shape of warp-level matrix multiply operator
    typename WarpShape,
    /// Shape of one matrix product operation (concept: GemmShape)
    typename InstructionShape,
    /// Element data type of A operand
    typename ElementA,
    /// Layout of operand A
    typename LayoutA,
    /// Element data type of B operand
    typename ElementB,
    /// Layout of operand B
    typename LayoutB,
    /// Data type of accumulator
    typename ElementC,
    /// Layout of accumulator
    typename LayoutC,
    /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp)
    typename OperatorClass,
    /// Size of a warp-scoped per thread access
    int kLaneAccessSizeA_ = 0,
    /// Size of a warp-scoped per thread access
    int kLaneAccessSizeB_ = 0,
    /// Number of stages
    int Stages = 2,
    /// Operation performed by MMA
    typename Operator = typename platform::conditional<
        (platform::is_same<OperatorClass,
                           cutlass::arch::OpClassTensorOp>::value) &&
            (platform::is_same<ElementA, int8_t>::value ||
             platform::is_same<ElementA, int4b_t>::value ||
             platform::is_same<ElementA, uint8_t>::value ||
             platform::is_same<ElementA, uint4b_t>::value),
        cutlass::arch::OpMultiplyAddSaturate,
        cutlass::arch::OpMultiplyAdd>::type,
    /// Store the accumulators in row major or column major.  Row major is used
    /// when output layout is interleaved.
    bool AccumulatorsInRowMajor = false,
    /// Cache operation of operand A
    cutlass::arch::CacheOperation::Kind CacheOpA =
        cutlass::arch::CacheOperation::Global,
    /// Cache operation of operand B
    cutlass::arch::CacheOperation::Kind CacheOpB =
        cutlass::arch::CacheOperation::Global,
    /// per-element transformation for elements of A
    ComplexTransform TransformA = ComplexTransform::kNone,
    /// per-element transformation for elements of B
    ComplexTransform TransformB = ComplexTransform::kNone,
    bool IsComplex = false  // (is_complex<ElementA>::value || is_complex<ElementB>::value)
>
struct DepthwiseMmaCoreWithLaneAccessSize;

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// Shape of threadblock-scoped matrix multiply operator
    typename Shape,
    /// Shape of warp-level matrix multiply operator
    typename WarpShape,
    /// Shape of one matrix product operation (concept: GemmShape)
    typename InstructionShape,
    /// Element data type of A operand
    typename ElementA,
    /// Layout of operand A
    typename LayoutA,
    /// Element data type of B operand
    typename ElementB,
    /// Layout of operand B
    typename LayoutB,
    /// Data type of accumulator
    typename ElementC,
    /// Layout of accumulator
    typename LayoutC,
    /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp)
    typename OperatorClass,
    /// Number of stages
    int Stages,
    /// Operation performed by MMA
    typename Operator,
    /// Store the accumulators in row major or column major.  Row major is used
    /// when output layout is interleaved.
    bool AccumulatorsInRowMajor,
    /// Cache operation of operand A
    cutlass::arch::CacheOperation::Kind CacheOpA,
    /// Cache operation of operand B
    cutlass::arch::CacheOperation::Kind CacheOpB,
    /// per-element transformation for elements of A
    ComplexTransform TransformA,
    /// per-element transformation for elements of B
    ComplexTransform TransformB,
    bool IsComplex
>
struct DepthwiseMmaCoreWithLaneAccessSize<
    Shape, WarpShape, InstructionShape,
    ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
    OperatorClass, -1, -1, Stages, Operator, AccumulatorsInRowMajor,
    CacheOpA, CacheOpB, TransformA, TransformB, IsComplex
> : cutlass::gemm::threadblock::DefaultMmaCore<
    Shape, WarpShape, InstructionShape,
    ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
    OperatorClass, Stages, Operator, AccumulatorsInRowMajor,
    CacheOpA, CacheOpB, TransformA, TransformB, IsComplex
> {};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization:
///
///   A: row-major
///   B: column-major
///   Operator: simt class
///
/// This uses the default warp-level operator given tile sizes
template <
    /// Shape of threadblock-scoped matrix multiply operator (concept:
    /// GemmShape)
    typename Shape_,
    /// Shape of warp-level matrix multiply operator (concept: GemmShape)
    typename WarpShape_,
    /// Data type of A operand
    typename ElementA_,
    /// Data type of B operand
    typename ElementB_,
    /// Data type of accumulator
    typename ElementC_,
    /// Layout of accumulator
    typename LayoutC_,
    /// Size of a warp-scoped per thread access (a value of -1 indicates the default)
    int kLaneAccessSizeA_,
    /// Size of a warp-scoped per thread access (a value of -1 indicates the default)
    int kLaneAccessSizeB_,
    /// Operation performed by GEMM
    typename Operator_>
struct DepthwiseMmaCoreWithLaneAccessSize<Shape_,
                                          WarpShape_,
                                          cutlass::gemm::GemmShape<1, 1, 1>,
                                          ElementA_,
                                          layout::RowMajor,
                                          ElementB_,
                                          layout::ColumnMajor,
                                          ElementC_,
                                          LayoutC_,
                                          arch::OpClassSimt,
                                          kLaneAccessSizeA_,
                                          kLaneAccessSizeB_,
                                          2,
                                          Operator_>
    : public cutlass::gemm::threadblock::DefaultMmaCore<Shape_,
                                                        WarpShape_,
                                                        cutlass::gemm::GemmShape<1, 1, 1>,
                                                        ElementA_,
                                                        layout::RowMajor,
                                                        ElementB_,
                                                        layout::ColumnMajor,
                                                        ElementC_,
                                                        LayoutC_,
                                                        arch::OpClassSimt,
                                                        2,
                                                        Operator_> {
  using Base = cutlass::gemm::threadblock::DefaultMmaCore<Shape_,
                                                          WarpShape_,
                                                          cutlass::gemm::GemmShape<1, 1, 1>,
                                                          ElementA_,
                                                          layout::RowMajor,
                                                          ElementB_,
                                                          layout::ColumnMajor,
                                                          ElementC_,
                                                          LayoutC_,
                                                          arch::OpClassSimt,
                                                          2,
                                                          Operator_>;

  using Shape = Shape_;
  using WarpShape = WarpShape_;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
  using ElementA = ElementA_;
  using LayoutA = layout::RowMajor;
  using ElementB = ElementB_;
  using LayoutB = layout::ColumnMajor;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using OperatorClass = arch::OpClassSimt;

  static int const kLaneAccessSizeA = kLaneAccessSizeA_;
  static int const kLaneAccessSizeB = kLaneAccessSizeB_;

  // Divisibility requirements
  static_assert( kLaneAccessSizeA > 0 && kLaneAccessSizeB > 0,
    "Size of a warp-scoped per thread access should be larger than ZERO" );

  /// Default Operator
  using Operator = Operator_;

  /// Number of warps present
  using WarpCount = typename Base::WarpCount;

  // Divisibility requirements
  static_assert(
    !(Shape::kM % WarpShape::kM) &&
    !(Shape::kN % WarpShape::kN),
    "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
  );

  /// Number of threads per warp
  static int const kWarpSize = cutlass::gemm::warp::WarpSize<arch::OpClassSimt>::value;

  static int const kElementsPerAccess = 1;

  //
  // Shared memory layouts
  //

  using SmemLayoutA = layout::ColumnMajor;
  using SmemLayoutB = layout::RowMajor;

  //
  // Iterators to write to shared memory are same as base class
  //

  //
  // Warp-level matrix multiply operator
  //

  // Define the warp-level op
  static const int WarpNumThreadsM = cutlass::gemm::threadblock::detail::simt_get_warp_threads_m<WarpShape>();
  static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM;
  static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM;
  static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN;
  static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN),
                "WarpShape must be divisible by ThreadTile shape.");
  static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1;
  static const int numElementsA = kLaneAccessSizeA / sizeof_bits<ElementA>::value;
  static const int numElementsB = kLaneAccessSizeB / sizeof_bits<ElementB>::value;
  static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM);
  static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN);

  static int const kPaddingM = cutlass::gemm::threadblock::detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits<ElementA>::value);
  static int const kPaddingN = cutlass::gemm::threadblock::detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits<ElementB>::value);

  static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN),
                "Padding must be divisible by Lane");

  // These should also be capped by the thread tile size
  using LaneMmaShape = cutlass::gemm::GemmShape<
      LaneM,
      LaneN,
      1>;
  using Policy = cutlass::gemm::warp::MmaSimtPolicy<
      cutlass::MatrixShape<WarpNumThreadsM, WarpNumThreadsN>,   // WarpShape
      cutlass::layout::RowMajorInterleaved<LaneLayout>,         // LaneLayout
      LaneMmaShape
  >;

  using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseSimt<
      WarpShape,    /// Size of the Gemm problem - concept: gemm::GemmShape<>
      ElementA,     /// Data type of A elements
      SmemLayoutA,  /// Layout of A matrix (concept: MatrixLayout)
      ElementB,     /// Data type of B elements
      SmemLayoutB,  /// Layout of B matrix (concept: MatrixLayout)
      ElementC,     /// Element type of C matrix
      LayoutC,      /// Layout of C matrix (concept: MatrixLayout)
      Policy        /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy)
  >;

  /// Policy used to define MmaPipelined
  using MmaPolicy = cutlass::gemm::threadblock::MmaPolicy<
      MmaWarpSimt,
      MatrixShape<kPaddingM, 0>,    // skew for A matrix to avoid SMEM bank conflicts
      MatrixShape<0, kPaddingN>,    // skew for B matrix to avoid SMEM bank conflicts
      WarpCount::kK
  >;
};

} // namespace threadblock
} // namespace conv
} // namespace cutlass
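Editor's note: the lane-access knobs translate into the per-thread MMA shape by dividing the access width by the element width, then clamping to the thread tile (LaneM = const_min(numElementsA, ThreadTileM) above). A constexpr check with assumed values — 128-bit lane access, fp16 elements, a 4-row thread tile, none of which come from the diff:

#include <algorithm>

constexpr int kLaneAccessSizeA = 128;   // bits per lane access (assumed)
constexpr int sizeof_bits_fp16 = 16;
constexpr int ThreadTileM = 4;

constexpr int numElementsA = kLaneAccessSizeA / sizeof_bits_fp16;  // 8 elements
constexpr int LaneM = std::min(numElementsA, ThreadTileM);         // clamped to 4

static_assert(LaneM == 4, "per-thread MMA rows never exceed the thread tile");

int main() { return 0; }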
@ -64,7 +64,7 @@

#include "cutlass/arch/cache_operation.h"
#include "cutlass/gemm/gemm.h"

#include "cutlass/conv/warp/conv2d_fprop_scale_bias_iterator.h"
#include "cutlass/gemm/warp/scale_bias_tile_iterator.h"
#include "cutlass/conv/warp/scale_bias_relu_transform.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -139,6 +139,13 @@ class MmaFpropFusionBase {

  /// Tensor reference to the B operand
  using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;

  static_assert(kWarpGemmIterations > 1,
                "The pipelined structure requires at least two warp-level "
                "GEMM operations.");

  static_assert((kWarpGemmIterations % 2) == 0,
                "Inner loop iteration must be an even number.");

  //
  // Nested structs
  //

@ -319,7 +326,7 @@ class ImplicitGemmFpropFusionMultistage

  using Policy = Policy_;
  ///< Base class
  using Base = MmaFpropFusionBase<Shape_, typename IteratorScaleBias::Element,
                                  typename IteratorScaleBias::Layout, Policy_,
                                  typename IteratorScaleBias::Layout, Policy,
                                  WarpIteratorScaleBias, Stages>;

  using SmemIteratorA = SmemIteratorA_;

@ -518,6 +525,8 @@ public:

      IteratorScaleBias iterator_A_scale_bias,
      ///< initial value of accumulator
      FragmentC const &src_accum,
      ///< number of iterations per channel
      int gemm_k_iterations_per_channel = 0,
      ///< Imaginary strides used for planar-complex only - ignored here
      int64_t imag_stride_A = 0,
      int64_t imag_stride_B = 0) {

@ -116,10 +116,6 @@ public:

  /// Internal structure exposed for introspection.
  struct Detail {

    static_assert(Base::kWarpGemmIterations > 1,
                  "The pipelined structure requires at least two warp-level "
                  "GEMM operations.");

    /// Number of cp.async instructions to load one stage of operand A
    static int const AsyncCopyIterationsPerStageA =
        IteratorA::ThreadMap::Iterations::kCount;

@ -272,6 +268,8 @@ public:

      IteratorB iterator_B,
      ///< initial value of accumulator
      FragmentC const &src_accum,
      ///< number of iterations per channel
      int gemm_k_iterations_per_channel = 0,
      ///< Imaginary strides used for planar-complex only - ignored here
      int64_t imag_stride_A = 0,
      int64_t imag_stride_B = 0) {

@ -297,7 +295,7 @@ public:

      CUTLASS_PRAGMA_UNROLL
      for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
        int const kSrcBytes =
            sizeof_bits<typename IteratorA::Element>::value *
            IteratorA::ThreadMap::kElementsPerAccess /
            IteratorA::kAccessesPerVector / 8;

@ -322,7 +320,7 @@ public:

          this->smem_iterator_B_.get());

      CUTLASS_PRAGMA_UNROLL
      for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
      for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
        int const kSrcBytes =
            sizeof_bits<typename IteratorB::Element>::value *
            IteratorB::ThreadMap::kElementsPerAccess /
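Editor's note: the last hunk above fixes a classic copy-paste bug — the B-operand loop was bounded by IteratorA::kAccessesPerVector. One way to make that kind of mismatch unrepresentable is to derive the byte count from a single iterator parameter. A sketch with simplified, made-up traits (not the CUTLASS thread-map types):

template <typename Iterator>
constexpr int src_bytes_per_access() {
  // Each access moves kElementsPerAccess elements, split over
  // kAccessesPerVector transactions of kSizeofBits each.
  return Iterator::kSizeofBits * Iterator::kElementsPerAccess /
         Iterator::kAccessesPerVector / 8;
}

struct FakeIteratorB {  // stand-in for a real tile iterator
  static constexpr int kSizeofBits = 16;
  static constexpr int kElementsPerAccess = 8;
  static constexpr int kAccessesPerVector = 2;
};

static_assert(src_bytes_per_access<FakeIteratorB>() == 8,
              "16-bit x 8 elements / 2 transactions = 8 bytes each");

int main() { return 0; }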
@ -188,6 +188,7 @@ public:

      IteratorA iterator_A,                     ///< iterator over A operand in global memory
      IteratorB iterator_B,                     ///< iterator over B operand in global memory
      FragmentC const &src_accum,               ///< source accumulator tile
      int gemm_k_iterations_per_channel = 0,    ///< number of iterations per channel
      TransformA transform_A = TransformA(),    ///< transformation applied to A fragment
      TransformB transform_B = TransformB()) {  ///< transformation applied to B fragment
@ -70,7 +70,7 @@

#include "cutlass/arch/cache_operation.h"
#include "cutlass/gemm/gemm.h"

#include "cutlass/conv/warp/conv2d_fprop_scale_bias_iterator.h"
#include "cutlass/gemm/warp/scale_bias_tile_iterator.h"
#include "cutlass/conv/warp/scale_bias_relu_transform.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -138,6 +138,13 @@ class MmaWgradFusionBase {

  /// Tensor reference to the B operand
  using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;

  static_assert(kWarpGemmIterations > 1,
                "The pipelined structure requires at least two warp-level "
                "GEMM operations.");

  static_assert((kWarpGemmIterations % 2) == 0,
                "Inner loop iteration must be an even number.");

  //
  // Nested structs
  //

@ -306,10 +313,6 @@ class ImplicitGemmWgradFusionMultistage

  /// Internal structure exposed for introspection.
  struct Detail {

    static_assert(Base::kWarpGemmIterations > 1,
                  "The pipelined structure requires at least two warp-level "
                  "GEMM operations.");

    /// Number of cp.async instructions to load one stage of operand A
    static int const AsyncCopyIterationsPerStageA =
        IteratorA::ThreadMap::Iterations::kCount;

@ -470,6 +473,8 @@ public:

      IteratorScaleBias iterator_B_scale_bias,
      ///< initial value of accumulator
      FragmentC const &src_accum,
      ///< number of iterations per channel
      int gemm_k_iterations_per_channel = 0,
      ///< Imaginary strides used for planar-complex only - ignored here
      int64_t imag_stride_A = 0,
      int64_t imag_stride_B = 0) {
@ -113,12 +113,9 @@ class PredicatedScaleBiasVectorAccessIterator<ThreadblockShape_,

  /// Internal pointer to first access of tile
  BytePointer pointer_;

  /// Size of tensor
  Conv2dProblemSize problem_size_;

  int filter_c_;
  int filter_r_;
  int filter_s_;
  int problem_size_trs;
  int problem_size_c;
  int filter_trs_;

  TensorCoord thread_offset_;

@ -140,10 +137,43 @@ class PredicatedScaleBiasVectorAccessIterator<ThreadblockShape_,

      /// Initial offset of threadblock
      TensorCoord const &threadblock_offset)
      : params_(params),
        problem_size_(problem_size),
        filter_c_(0),
        filter_r_(0),
        filter_s_(0) {
        problem_size_trs(problem_size.R * problem_size.S),
        problem_size_c(problem_size.C),
        filter_trs_(0) {
    pointer_ = (thread_id < kThreads)
                   ? reinterpret_cast<BytePointer>(
                         const_cast<NonConstPointer>(scale_pointer))
                   : reinterpret_cast<BytePointer>(
                         const_cast<NonConstPointer>(bias_pointer));

    // Per-thread offset in logical coordinates of tensor
    int thread_base = (thread_id < kThreads) ? 0 : kThreads;

    thread_offset_ =
        threadblock_offset +
        TensorCoord((thread_id - thread_base) * kElementsPerAccess, 0);

    set_iteration_index(0);
  }

  CUTLASS_HOST_DEVICE
  PredicatedScaleBiasVectorAccessIterator(
      /// Precomputed parameters object
      Params const &params,
      /// Extent of tensor
      Conv3dProblemSize const &problem_size,
      /// Pointer to the start of the scale vector
      ConstPointer scale_pointer,
      /// Pointer to the start of the bias vector
      ConstPointer bias_pointer,
      /// ID of each participating thread
      int thread_id,
      /// Initial offset of threadblock
      TensorCoord const &threadblock_offset)
      : params_(params),
        problem_size_trs(problem_size.T * problem_size.R * problem_size.S),
        problem_size_c(problem_size.C),
        filter_trs_(0) {
    pointer_ = (thread_id < kThreads)
                   ? reinterpret_cast<BytePointer>(
                         const_cast<NonConstPointer>(scale_pointer))

@ -177,6 +207,22 @@ class PredicatedScaleBiasVectorAccessIterator<ThreadblockShape_,

                                            scale_pointer, bias_pointer,
                                            thread_id, make_Coord(0, 0)) {}

  CUTLASS_HOST_DEVICE
  PredicatedScaleBiasVectorAccessIterator(
      /// Precomputed parameters object
      Params const &params,
      /// Extent of tensor
      Conv3dProblemSize const &problem_size,
      /// Pointer to the start of the scale vector
      ConstPointer scale_pointer,
      /// Pointer to the start of the bias vector
      ConstPointer bias_pointer,
      /// ID of each participating thread
      int thread_id)
      : PredicatedScaleBiasVectorAccessIterator(params, problem_size,
                                                scale_pointer, bias_pointer,
                                                thread_id, make_Coord(0, 0)) {}

  /// Overrides the internal iteration index
  CUTLASS_HOST_DEVICE
  void set_iteration_index(int index) {}

@ -209,16 +255,10 @@ class PredicatedScaleBiasVectorAccessIterator<ThreadblockShape_,

  CUTLASS_HOST_DEVICE
  void advance() {
    // moves to the next tile
    ++filter_s_;
    if (filter_s_ == problem_size_.S) {
      filter_s_ = 0;
      ++filter_r_;

      if (filter_r_ < problem_size_.R) {
      } else {
        filter_r_ = 0;
        add_tile_offset(TensorCoord(1, 0));
      }
    ++filter_trs_;
    if (filter_trs_ == problem_size_trs) {
      filter_trs_ = 0;
      add_tile_offset(TensorCoord(1, 0));
    }
  }
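Editor's note: the refactor above replaces the nested (r, s) counters with a single flattened trs counter, so one comparison covers the whole filter footprint and the same iterator works for both Conv2d (R*S) and Conv3d (T*R*S) problem sizes. A small model showing both schemes stay in lockstep (sizes are hypothetical):

#include <cassert>

int main() {
  const int R = 3, S = 3;
  int r = 0, s = 0;          // nested counters (old scheme)
  int trs = 0;               // flattened counter (new scheme)
  int wraps_nested = 0, wraps_flat = 0;

  for (int step = 0; step < 25; ++step) {
    if (++s == S) { s = 0; if (++r == R) { r = 0; ++wraps_nested; } }
    if (++trs == R * S) { trs = 0; ++wraps_flat; }
    assert(trs == r * S + s);          // the two encodings agree at every step
  }
  assert(wraps_nested == wraps_flat);  // tile-offset bumps happen at the same steps
  return 0;
}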
@ -248,7 +288,7 @@ class PredicatedScaleBiasVectorAccessIterator<ThreadblockShape_,

        "}\n" : "+r"(enabled) :"n"(kThreads * 2));
#endif

    return ((thread_offset_.contiguous() < problem_size_.C) && enabled);
    return ((thread_offset_.contiguous() < problem_size_c) && enabled);
  }
};
@ -322,6 +362,25 @@ class PredicatedScaleBiasVectorAccessIterator<ThreadblockShape_,

                  layout::PitchLinearCoord(threadblock_offset.column(),
                                           threadblock_offset.row())) {}

  CUTLASS_HOST_DEVICE
  PredicatedScaleBiasVectorAccessIterator(
      ///< Precomputed parameters object
      Params const &params,
      ///< Extent of tensor
      Conv3dProblemSize const &problem_size,
      ///< Pointer to the start of the scale vector
      ConstPointer scale_pointer,
      ///< Pointer to the start of the bias vector
      ConstPointer bias_pointer,
      ///< ID of each participating thread
      int thread_id,
      ///< Initial offset of threadblock
      TensorCoord const &threadblock_offset)
      : iterator_(params, problem_size, scale_pointer, bias_pointer,
                  thread_id,
                  layout::PitchLinearCoord(threadblock_offset.column(),
                                           threadblock_offset.row())) {}

  /// Construct a PredicatedTileAccessIterator with zero threadblock offset
  CUTLASS_HOST_DEVICE
  PredicatedScaleBiasVectorAccessIterator(

@ -335,6 +394,18 @@ class PredicatedScaleBiasVectorAccessIterator<ThreadblockShape_,

                                            scale_pointer, bias_pointer,
                                            thread_id, make_Coord(0, 0)) {}

  CUTLASS_HOST_DEVICE
  PredicatedScaleBiasVectorAccessIterator(
      Params const &params,                   ///< Precomputed parameters object
      Conv3dProblemSize const &problem_size,  ///< Extent of tensor
      ConstPointer scale_pointer,             ///< Pointer to the start of the scale vector
      ConstPointer bias_pointer,              ///< Pointer to the start of the bias vector
      int thread_id                           ///< ID of each participating thread
  )
      : PredicatedScaleBiasVectorAccessIterator(params, problem_size,
                                                scale_pointer, bias_pointer,
                                                thread_id, make_Coord(0, 0)) {}

  /// Overrides the internal iteration index
  CUTLASS_HOST_DEVICE
  void set_iteration_index(int index) { iterator_.set_iteration_index(index); }

@ -157,7 +157,6 @@ struct StridedDgradIdentityThreadblockSwizzle :

                           split_k_slices);
  }

  /// Returns the shape of the problem in units of logical tiles
  /// For GEMM problem size (MxNxK) (Do not use base class get_tiled_shape())
  private:
163  include/cutlass/conv/warp/mma_depthwise_simt.h  Normal file
@ -0,0 +1,163 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing warp-level matrix multiply-accumulate operations.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma.h"

#include "cutlass/gemm/thread/mma.h"

#include "cutlass/gemm/warp/mma_simt_tile_iterator.h"
#include "cutlass/gemm/warp/mma_simt_policy.h"

#include "cutlass/gemm/warp/mma_simt.h"
#include "cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace warp {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Data type of A elements
    typename ElementA_,
    /// Layout of A matrix (concept: MatrixLayout)
    typename LayoutA_,
    /// Data type of B elements
    typename ElementB_,
    /// Layout of B matrix (concept: MatrixLayout)
    typename LayoutB_,
    /// Element type of C matrix
    typename ElementC_,
    /// Layout of C matrix (concept: MatrixLayout)
    typename LayoutC_,
    /// Shape of the warp in units of thread (concept: MmaSimtPolicy)
    typename Policy_,
    /// Number of partitions along K dimension
    int PartitionsK = 1,
    /// Complex transformation on operand A
    ComplexTransform TransformA = ComplexTransform::kNone,
    /// Complex transformation on operand B
    ComplexTransform TransformB = ComplexTransform::kNone,
    /// Used for partial specialization
    typename Enable = bool>
class MmaDepthwiseSimt
    : public cutlass::gemm::warp::
          MmaSimt<Shape_, ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, LayoutC_, Policy_> {
  using Base = cutlass::gemm::warp::
      MmaSimt<Shape_, ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, LayoutC_, Policy_>;

public:
  /// Shape of warp-level matrix operation (concept: GemmShape)
  using Shape = Shape_;  // e.g. <64, 16, 8>

  /// Data type of multiplicand A
  using ElementA = ElementA_;

  /// Layout of multiplicand A
  using LayoutA = LayoutA_;

  /// Data type of multiplicand B
  using ElementB = ElementB_;

  /// Layout of multiplicand B
  using LayoutB = LayoutB_;

  /// Data type of accumulator matrix C
  using ElementC = ElementC_;

  /// Layout of accumulator matrix C
  using LayoutC = LayoutC_;

  /// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
  using Policy = Policy_;

  /// Indicates class of matrix operator
  using OperatorClass = arch::OpClassSimt;

  /// Hard-coded for now
  using ArchTag = arch::Sm50;

  /// Complex transform on A operand
  static ComplexTransform const kTransformA = TransformA;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB = TransformB;

public:

  /// Iterates over the B operand in memory
  using IteratorB = cutlass::conv::warp::DepthwiseMmaSimtTileIterator<
      MatrixShape<Policy::LaneMmaShape::kK, Shape::kN>,
      cutlass::gemm::Operand::kB,
      ElementB,
      LayoutB,
      Policy,
      PartitionsK,
      Shape::kK
  >;

  /// Storage for B tile
  using FragmentB = typename IteratorB::Fragment;

  /// Storage for transformed B tile
  using TransformedFragmentB = FragmentB;

public:

  //
  // Methods
  //

  /// Ctor
  CUTLASS_DEVICE
  MmaDepthwiseSimt() : Base() {}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace conv
} // namespace cutlass
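Editor's note: MmaDepthwiseSimt reuses everything from the warp-level MmaSimt base and swaps in only the B-operand iterator type. The pattern in miniature — all names here are illustrative stand-ins, not CUTLASS types:

struct BaseOp {
  struct IteratorB { static constexpr int kStride = 1; };
  int run() const { return 42; }  // inherited behavior stays available
};

struct DepthwiseOp : BaseOp {
  // Shadow the nested type: code written against DepthwiseOp::IteratorB
  // picks up the depthwise-aware iterator; everything else is inherited.
  struct IteratorB { static constexpr int kStride = 8; };
};

static_assert(DepthwiseOp::IteratorB::kStride == 8, "derived nested type wins");

int main() { return DepthwiseOp{}.run() == 42 ? 0 : 1; }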
255  include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h  Normal file
@ -0,0 +1,255 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines iterators used by warp-level matrix multiply operations targeting depthwise
      convolution with SIMT instructions
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/arch/memory_sm75.h"

#include "cutlass/layout/matrix.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma_simt_policy.h"
#include "cutlass/gemm/warp/mma_simt_tile_iterator.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace warp {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Iterates over operands to warp-level matrix multiply operations targeting SIMT instructions
///
/// concept: MutableRandomAccessContiguousTileIteratorConcept
///
template <
  /// Size of the matrix to load (concept: MatrixShape)
  typename Shape_,
  /// Operand identity
  cutlass::gemm::Operand Operand,
  /// Data type of elements
  typename Element_,
  /// Layout of operand
  typename Layout_,
  /// Shape of the warp in units of thread (concept: MmaSimtPolicy)
  typename Policy_,
  /// Number of partitions along K dimension - used in sliced-K
  int PartitionsK = 1,
  /// Group Size along kPartition - used in sliced-K
  int PartitionGroupSize = 1
>
class DepthwiseMmaSimtTileIterator;

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Specialization for B operands of row-major layouts
///
/// Concept: MutableRandomAccessContiguousTileIteratorConcept
///
template <
  /// Size of the matrix to load (concept: MatrixShape)
  typename Shape_,
  /// Data type of B elements
  typename Element_,
  /// Shape of the warp in units of thread (concept: MmaSimtPolicy)
  typename Policy_,
  /// Number of partitions along K dimension
  int PartitionsK,
  /// Group Size along kPartition - used in sliced-K
  int PartitionGroupSize>
class DepthwiseMmaSimtTileIterator<Shape_,
                                   cutlass::gemm::Operand::kB,
                                   Element_,
                                   layout::RowMajor,
                                   Policy_,
                                   PartitionsK,
                                   PartitionGroupSize>
    : public cutlass::gemm::warp::MmaSimtTileIterator<Shape_,
                                                      cutlass::gemm::Operand::kB,
                                                      Element_,
                                                      layout::RowMajor,
                                                      Policy_,
                                                      PartitionsK,
                                                      PartitionGroupSize> {

  using Base = cutlass::gemm::warp::MmaSimtTileIterator<Shape_,
                                                        cutlass::gemm::Operand::kB,
                                                        Element_,
                                                        layout::RowMajor,
                                                        Policy_,
                                                        PartitionsK,
                                                        PartitionGroupSize>;

public:
  /// Shape of tile to load (concept: MatrixShape)
  using Shape = Shape_;

  /// Operand tag
  static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kB;

  /// Element type
  using Element = Element_;

  /// Layout of policy
  using Layout = layout::RowMajor;

  /// Decomposition of elements among threads
  using Policy = Policy_;

  /// TensorRef type for loading element from a tensor
  using TensorRef = typename Base::TensorRef;

  /// Index type
  using Index = typename TensorRef::Index;

  /// Long Index type
  using LongIndex = typename TensorRef::LongIndex;

  /// Coordinate for an element in the tensor
  using TensorCoord = typename TensorRef::TensorCoord;

  /// Thread-level shape of a fragment
  using ThreadShape = typename Base::ThreadShape;

  /// Number of individual loads
  using Iterations = typename Base::Iterations;

  /// Fragment object holding a thread's part of a tile
  using Fragment = typename Base::Fragment;

  static_assert(Policy::LaneMmaShape::kN == 1,
                "Each thread should load 1 element per LDS along the k-dim");

private:

  MatrixCoord lane_offset_;
  int channel_idx_;
  int base_channel_idx_;
  int warps_n_;

public:

  /// Default ctor constructs null iterator
  CUTLASS_HOST_DEVICE
  DepthwiseMmaSimtTileIterator() : Base() { }

  /// Constructor from TensorRef
  CUTLASS_HOST_DEVICE
  DepthwiseMmaSimtTileIterator(
    TensorRef ref,
    int lane_id
  ) : Base(ref, lane_id) {

    // Compute offset based on thread ID and lane layout
    typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();

    warps_n_ = -1;
    channel_idx_ = 0;
    base_channel_idx_ = 0;
    lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN);
  }

  /// Advances an iterator along logical dimensions of matrix in units of whole tiles
  CUTLASS_HOST_DEVICE
  DepthwiseMmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) {

    if (warps_n_ == -1) {
      warps_n_ = coord.column();
    }

    Base::add_tile_offset(coord);
    return *this;
  }

  /// Loads a fragment from memory at the location pointed to by the iterator (vector loads)
  CUTLASS_HOST_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {
    Array<Element, Policy::LaneMmaShape::kN> *dst_ptr =
        reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < Iterations::kRow; ++k) {
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < Iterations::kColumn; ++n) {

        void const *ptr = this->ref_.data() +
                          this->ref_.offset({-(channel_idx_ - base_channel_idx_),
                                             n * Policy::WarpShape::kColumn}) +
                          pointer_offset / Policy::LaneMmaShape::kN;

        // Base k of the warp + base k of the current thread.
        int thread_k_base_idx =
            warps_n_ * Shape::kColumn / Policy::LaneMmaShape::kN + lane_offset_.column();

        if (channel_idx_ + k == thread_k_base_idx + n * Policy::WarpShape::kColumn) {
          // A depthwise kernel only computes when channel == k, so load an element
          // only when the current computation channel matches the k owned by this thread.
          arch::shared_load(dst_ptr[n + k * Iterations::kColumn], ptr);
        } else {
          // Zero-fill instead of loading, reducing shared memory traffic.
          dst_ptr[n + k * Iterations::kColumn].fill(Element(0));
        }
      }
    }
  }

  /// Loads a fragment from memory at the location pointed to by the iterator.
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {
    load_with_pointer_offset(frag, 0);
  }

  /// Notify the iterator which k-group it is currently pointing to.
  ///
  /// This does not advance the iterator. Rather, it overrides its internal
  /// tracking with a constant-valued k-group index.
  CUTLASS_DEVICE
  void set_kgroup_index(int k_group) {
    if (k_group % PartitionGroupSize == 0 && k_group != 0) {
      base_channel_idx_ = k_group;
    }
    channel_idx_ = k_group;
  }
};

///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace conv
} // namespace cutlass
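
Taken together, the two classes above reduce depthwise convolution to a GEMM whose B (filter) tile only contributes on the diagonal: output channel n consumes input channel k only when k == n, which is why the iterator zero-fills every fragment slot whose k does not match the thread's channel. A minimal host-side sketch of that structure (plain C++, not CUTLASS API):

#include <vector>

// C[m][n] += A[m][k] * B[k][n], but for depthwise convolution B[k][n] is
// nonzero only when k == n, so the reduction collapses to a single term.
void depthwise_gemm_reference(int M, int N,
                              std::vector<std::vector<float>> const &A,  // M x N activations
                              std::vector<float> const &b_diag,          // N per-channel filter taps
                              std::vector<std::vector<float>> &C) {      // M x N outputs
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      // The full GEMM would sum over k; only k == n survives.
      C[m][n] += A[m][n] * b_diag[n];
    }
  }
}
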
@ -101,7 +101,7 @@ struct FpropScaleBiasReluTransform {
          "}\n"
          : "=r"(ptr_activations[0])
          : "r"(ptr_scale_bias[0]), "r"(ptr_activations[0]),
-           "r"(ptr_scale_bias[1]), "n"(0x7eff7eff));
+           "r"(ptr_scale_bias[1]), "n"(cutlass::arch::OOB_NAN_F16x2));
#else
      // TODO: write emulation code
      assert(0);
@ -151,8 +151,8 @@ struct WgradScaleBiasReluTransform {
#if 1
    // CUDA + PTX version

-   bool h1_oob = (reinterpret_cast<uint16_t &>(ptr_activations[0].x) == 0x7eff);
-   bool h2_oob = (reinterpret_cast<uint16_t &>(ptr_activations[0].y) == 0x7eff);
+   bool h1_oob = (reinterpret_cast<uint16_t &>(ptr_activations[0].x) == cutlass::arch::OOB_NAN_F16);
+   bool h2_oob = (reinterpret_cast<uint16_t &>(ptr_activations[0].y) == cutlass::arch::OOB_NAN_F16);

    // Apply per channel scale+bias+relu if the data is not a special NaN
    // (0x7eff). If it is a special NaN (0x7eff), hard code the output to 0.
@ -161,7 +161,7 @@ struct WgradScaleBiasReluTransform {
    // out-of-bound because C x R x S can be an odd number.
    asm volatile(
        "{\n\t"
-       " fma.rn.f16x2.relu %0 , %1, %2, %3;\n"
+       " fma.rn.f16x2.relu %0, %1, %2, %3;\n"
        "}"
        : "=r"(reinterpret_cast<uint32_t &>(ptr_activations[0]))
        : "r"(ptr_scale_bias[0]), "r"(reinterpret_cast<uint32_t &>(ptr_activations[0])),
@ -195,7 +195,7 @@ struct WgradScaleBiasReluTransform {
        "}\n"
        : "=r"(reinterpret_cast<uint32_t &>(ptr_activations[0]))
        : "r"(ptr_scale_bias[0]), "r"(reinterpret_cast<uint32_t &>(ptr_activations[0])),
-       "r"(ptr_scale_bias[1]), "n"(0x7eff), "n"(0xffff0000), "n"(0x0000ffff));
+       "r"(ptr_scale_bias[1]), "n"(cutlass::arch::OOB_NAN_F16), "n"(0xffff0000), "n"(0x0000ffff));
#endif
#else
    // TODO: write emulation code
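
These hunks replace the magic constants 0x7eff / 0x7eff7eff with the named constants cutlass::arch::OOB_NAN_F16 and OOB_NAN_F16x2. Per the comments above, out-of-bounds fp16 elements are tagged with that sentinel NaN, and the transform must emit 0 for them instead of applying scale, bias, and ReLU. A scalar sketch of the guarded rule (the PTX paths above fuse this into fma.rn.f16x2.relu):

#include <algorithm>
#include <cstdint>

// Scalar equivalent of the guarded transform, assuming the half value has
// already been reinterpreted as its raw 16-bit pattern. Out-of-bounds loads
// were filled with the sentinel NaN 0x7eff and must produce 0.
float scale_bias_relu_guarded(uint16_t bits, float x, float scale, float bias) {
  uint16_t const kOobNanF16 = 0x7eff;  // sentinel written for out-of-bounds elements
  if (bits == kOobNanF16) {
    return 0.0f;  // hard-code the output to 0, as the PTX path does
  }
  return std::max(scale * x + bias, 0.0f);  // fused scale + bias + ReLU
}
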
@ -43,7 +43,7 @@

////////////////////////////////////////////////////////////////////////////////////////////////////

-#define CUTLASS_UNUSED(expr) do { (void)(expr); } while (0)
+#define CUTLASS_UNUSED(expr) do { ; } while (&expr != &expr)

#if !defined(__CUDACC_RTC__)

@ -192,4 +192,3 @@ CUTLASS_HOST_DEVICE bool thread0() {
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////////////////////////
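
The revised macro still expands to a statement, but its condition &expr != &expr is always false, so the empty body runs exactly once while the condition references expr without evaluating or casting it. The likely motivation, inferred rather than stated in the diff, is to silence unused-variable warnings in contexts where the old (void)(expr) cast is unavailable or itself warns. A small standalone usage sketch:

// Standalone demo of the revised macro; the rationale in the comments is an
// inference, not taken from the commit message.
#define CUTLASS_UNUSED(expr) do { ; } while (&expr != &expr)

void example(int debug_only_counter) {
  // The always-false address comparison "uses" debug_only_counter without
  // reading its value, suppressing -Wunused-parameter style diagnostics.
  CUTLASS_UNUSED(debug_only_counter);
}
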
@ -98,6 +98,32 @@ struct ReLu<Array<T, N>> {
  }
};

+// Leaky Relu operator
+template <typename T>
+struct LeakyReLU {
+  CUTLASS_HOST_DEVICE
+  T operator()(T const &value, T const &alpha_recip) const {
+    T res = value > T(0) ? value : value * alpha_recip;
+    return res;
+  }
+};
+
+template <typename T, int N>
+struct LeakyReLU<Array<T, N> > {
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const &rhs, T const &alpha_recip) const {
+    Array<T, N> y;
+    LeakyReLU<T> leaky_op;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < int(rhs.size()); ++i) {
+      y[i] = leaky_op(rhs[i], alpha_recip);
+    }
+
+    return y;
+  }
+};
+
// Tanh operator
template <typename T>
struct Tanh {
@ -135,32 +161,6 @@ struct Tanh<Array<half_t, N>> {
  }
};

-// Leaky Relu operator
-template <typename T>
-struct LeakyReLU {
-  CUTLASS_HOST_DEVICE
-  T operator()(T const &value, T const &alpha_recip) const {
-    T res = value > T(0) ? value : value * alpha_recip;
-    return res;
-  }
-};
-
-template <typename T, int N>
-struct LeakyReLU<Array<T, N> > {
-  CUTLASS_HOST_DEVICE
-  Array<T, N> operator()(Array<T, N> const &rhs, T const &alpha_recip) const {
-    Array<T, N> y;
-    LeakyReLU<T> leaky_op;
-
-    CUTLASS_PRAGMA_UNROLL
-    for (int i = 0; i < int(rhs.size()); ++i) {
-      y[i] = leaky_op(rhs[i], alpha_recip);
-    }
-
-    return y;
-  }
-};
-
// Sigmoid operator
template <typename T>
struct Sigmoid {
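
The array specialization simply maps the scalar rule across a fragment. A plain C++ sketch of that rule applied to a small fragment (values illustrative; the second argument, named alpha_recip in the header, is the slope applied to negative inputs):

#include <array>

// Scalar leaky ReLU matching the functor above.
float leaky_relu(float value, float alpha_recip) {
  return value > 0.0f ? value : value * alpha_recip;
}

int main() {
  std::array<float, 4> frag{1.5f, -2.0f, 0.0f, -0.5f};
  for (float &x : frag) {
    x = leaky_relu(x, 0.1f);  // -> {1.5, -0.2, 0.0, -0.05}
  }
}
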
@ -157,7 +157,7 @@ public:
    if (k_partition) {
      beta_ = ElementCompute(1);
    }

    if (k_partition != k_partition_count - 1) {
      skip_elementwise_ = true;
    }
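
The two guards above encode the split-K epilogue convention: every partition after the first accumulates into the prior partial result (beta forced to 1), and the elementwise function is applied only by the last partition. A scalar sketch of one partition's epilogue step under that convention (names illustrative, not the CUTLASS epilogue API):

// One split-K partition's contribution to a single output element.
float splitk_epilogue_step(int k_partition, int k_partition_count,
                           float partial, float prior, float alpha,
                           float (*activation)(float)) {
  float beta = (k_partition != 0) ? 1.0f : 0.0f;  // accumulate after partition 0
  float out = alpha * partial + beta * prior;
  if (k_partition == k_partition_count - 1) {
    out = activation(out);  // elementwise applied only at the final partition
  }
  return out;
}
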
@ -65,6 +65,8 @@
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue.h"

+#include "cutlass/layout/permute.h"
+
/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
@ -79,7 +81,8 @@ template <
  typename WarpMmaSimt_,
  typename OutputOp_,
  int ElementsPerAccess,
- bool ScatterD = false
+ bool ScatterD = false,
+ typename PermuteDLayout = layout::NoPermute
>
struct DefaultEpilogueSimt {
@ -109,7 +112,8 @@ struct DefaultEpilogueSimt {
  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    OutputTileThreadMap,
    ElementOutput,
-   ScatterD
+   ScatterD,
+   PermuteDLayout
  >;

  using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt<
@ -310,7 +314,6 @@ struct DefaultEpilogueSimtAffineRankN {
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

@ -74,6 +74,8 @@
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/interleaved_epilogue.h"

+#include "cutlass/layout/permute.h"
+
////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
@ -166,7 +168,7 @@ template <
  typename ThreadMap
>
struct DefaultIteratorsTensorOp<float, int32_t, 4, ThreadblockShape, WarpShape, InstructionShape, ThreadMap> {

  using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
    WarpShape,
    InstructionShape,
@ -265,7 +267,7 @@ struct DefaultIteratorsTensorOp<
    layout::RowMajor
  >;

- using WarpTileIterator = typename cutlass::platform::conditional<
+ using WarpTileIterator = typename platform::conditional<
    (ThreadblockShape::kN == 256),
    WarpTileIteratorNotMixed,
    WarpTileIteratorMixed>::type;
@ -284,7 +286,7 @@ struct DefaultIteratorsTensorOp<
    int32_t
  >;

- using SharedLoadIterator = typename cutlass::platform::conditional<
+ using SharedLoadIterator = typename platform::conditional<
    (ThreadblockShape::kN == 256),
    SharedLoadIteratorNotMixed,
    SharedLoadIteratorMixed>::type;
@ -302,7 +304,8 @@ template <
  int PartitionsK,
  typename OutputOp_,
  int ElementsPerAccess,
- bool ScatterD = false
+ bool ScatterD = false,
+ typename PermuteDLayout = layout::NoPermute
>
struct DefaultEpilogueTensorOp {
@ -334,6 +337,7 @@ struct DefaultEpilogueTensorOp {
    OutputTileThreadMap,
    ElementOutput,
    ScatterD,
+   PermuteDLayout,
    UseCUDAStore
  >;

@ -570,7 +574,6 @@ struct DefaultEpilogueTensorOpAffineRankN {
};

////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for epilogues for TensorOps which use an
/// interleaved output layout. For this case, shared memory is not needed.
template <typename Shape_, typename WarpMmaTensorOp_, int PartitionsK,

@ -66,6 +66,8 @@

#include "cutlass/epilogue/threadblock/epilogue.h"

+#include "cutlass/layout/permute.h"
+
/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
@ -81,7 +83,8 @@ template <
  int PartitionsK,
  typename OutputOp_,
  int ElementsPerAccess,
- bool ScatterD = false
+ bool ScatterD = false,
+ typename PermuteDLayout = layout::NoPermute
>
struct DefaultEpilogueVoltaTensorOp {
@ -111,7 +114,8 @@ struct DefaultEpilogueVoltaTensorOp {
  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    OutputTileThreadMap,
    ElementOutput,
-   ScatterD
+   ScatterD,
+   PermuteDLayout
  >;

  using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp<
@ -326,7 +330,6 @@ struct DefaultEpilogueVoltaTensorOpAffineRankN {
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

@ -49,6 +49,8 @@
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h"

+#include "cutlass/layout/permute.h"
+
////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
@ -67,7 +69,8 @@ template <
  typename ElementVector,
  typename OutputOp,
  int ElementsPerAccess,
- bool ScatterD = false
+ bool ScatterD = false,
+ typename PermuteDLayout = layout::NoPermute
>
struct DefaultEpilogueWithBroadcastTensorOp {
@ -86,7 +89,8 @@ struct DefaultEpilogueWithBroadcastTensorOp {
  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    typename Base::OutputTileThreadMap,
    ElementOutput,
-   ScatterD
+   ScatterD,
+   PermuteDLayout
  >;

  //

@ -50,6 +50,8 @@
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/epilogue_with_reduction.h"

+#include "cutlass/layout/permute.h"
+
////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
@ -67,7 +69,8 @@ template <
  typename OutputOp,
  typename ReductionOp,
  int ElementsPerAccess,
- bool ScatterD = false
+ bool ScatterD = false,
+ typename PermuteDLayout = layout::NoPermute
>
struct DefaultEpilogueWithReductionTensorOp {
@ -89,7 +92,8 @@ struct DefaultEpilogueWithReductionTensorOp {
  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    typename Base::OutputTileThreadMap,
    ElementOutput,
-   ScatterD
+   ScatterD,
+   PermuteDLayout
  >;

  /// Define the epilogue
@ -120,7 +124,8 @@ template <
  typename OutputOp,
  typename ReductionOp,
  int ElementsPerAccess,
- bool ScatterD = false
+ bool ScatterD = false,
+ typename PermuteDLayout = layout::NoPermute
>
struct DefaultEpilogueWithReductionVoltaTensorOp {
@ -142,7 +147,8 @@ struct DefaultEpilogueWithReductionVoltaTensorOp {
  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    typename Base::OutputTileThreadMap,
    ElementOutput,
-   ScatterD
+   ScatterD,
+   PermuteDLayout
  >;

  /// Define the epilogue

@ -64,6 +64,8 @@

#include "cutlass/epilogue/threadblock/epilogue.h"

+#include "cutlass/layout/permute.h"
+
////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
@ -79,7 +81,8 @@ template <
  int PartitionsK,
  typename OutputOp_,
  int ElementsPerAccess,
- bool ScatterD = false
+ bool ScatterD = false,
+ typename PermuteDLayout = layout::NoPermute
>
struct DefaultEpilogueWmmaTensorOp {
@ -109,7 +112,8 @@ struct DefaultEpilogueWmmaTensorOp {
  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
    OutputTileThreadMap,
    ElementOutput,
-   ScatterD
+   ScatterD,
+   PermuteDLayout
  >;

  using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorWmmaTensorOp<
@ -0,0 +1,513 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Epilogue visitor for threadblock-scoped GEMMs that fuse softmax computations into the
    epilogue.

    The epilogue finds the max value in each row of the row-major output matrix and stores it.
    The max values also feed a further round of threadblock-scoped reduction, whose partial
    results are stored in a pre-allocated array and used later for the full reduction.
*/

#pragma once

/////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass/cutlass.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/fast_math.h"

namespace cutlass {
namespace epilogue {
namespace threadblock {

template <
  typename ThreadblockShape_,
  int ThreadCount,
  typename OutputTileIterator_,
  typename ElementAccumulator_,
  typename ElementNorm_,
  typename ElementSum_,
  typename ElementSoftmaxCompute_,
  typename ElementwiseFunctor_,
  bool UseMasking_ = false
>
class EpilogueVisitorSoftmax {
public:

  using ThreadblockShape = ThreadblockShape_;
  static int const kThreadCount = ThreadCount;

  using OutputTileIterator = OutputTileIterator_;
  using ElementwiseFunctor = ElementwiseFunctor_;

  static int const kIterations = OutputTileIterator::kIterations;
  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

  using ElementOutput = typename OutputTileIterator::Element;
  using LayoutOutput = cutlass::layout::RowMajor;
  using ElementAccumulator = ElementAccumulator_;

  using ElementNorm = ElementNorm_;
  using ElementSum = ElementSum_;
  using ElementSoftmaxCompute = ElementSoftmaxCompute_;

  using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
  using SoftmaxFragment = Array<ElementSoftmaxCompute, kElementsPerAccess>;
  using OutputVector = Array<ElementOutput, kElementsPerAccess>;
  using TensorRefD = TensorRef<ElementOutput, LayoutOutput>;

  static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
  static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);
  static bool const kUseMasking = UseMasking_;

  /// Argument structure
  struct Arguments {

    typename ElementwiseFunctor::Params elementwise;
    int64_t batch_stride_C;
    int64_t batch_stride_D;
    int64_t batch_stride_Max;
    int64_t batch_stride_Sum;

    //
    // Methods
    //
    Arguments():
      batch_stride_C(0),
      batch_stride_D(0),
      batch_stride_Max(0),
      batch_stride_Sum(0)
    { }

    Arguments(
      typename ElementwiseFunctor::Params elementwise_
    ):
      elementwise(elementwise_),
      batch_stride_C(0),
      batch_stride_D(0),
      batch_stride_Max(0),
      batch_stride_Sum(0)
    { }

    Arguments(
      typename ElementwiseFunctor::Params elementwise_,
      int64_t batch_stride_C_,
      int64_t batch_stride_D_,
      int64_t batch_stride_Max_,
      int64_t batch_stride_Sum_
    ):
      elementwise(elementwise_),
      batch_stride_C(batch_stride_C_),
      batch_stride_D(batch_stride_D_),
      batch_stride_Max(batch_stride_Max_),
      batch_stride_Sum(batch_stride_Sum_)
    { }
  };

  struct Params {

    typename ElementwiseFunctor::Params elementwise;
    int64_t batch_stride_C;
    int64_t batch_stride_D;
    int64_t batch_stride_Max;
    int64_t batch_stride_Sum;

    //
    // Methods
    //
    CUTLASS_HOST_DEVICE
    Params() { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args):
      elementwise(args.elementwise),
      batch_stride_C(args.batch_stride_C),
      batch_stride_D(args.batch_stride_D),
      batch_stride_Max(args.batch_stride_Max),
      batch_stride_Sum(args.batch_stride_Sum)
    { }
  };

  /// Shared storage
  struct SharedStorage { };

private:

  Params const &params_;
  SharedStorage &shared_storage_;
  MatrixCoord extent_;
  MatrixCoord extent_real_;
  ElementwiseFunctor elementwise_;

  OutputTileIterator iterator_C_;
  OutputTileIterator iterator_D_;
  typename OutputTileIterator::Fragment fragment_C_;
  typename OutputTileIterator::Fragment fragment_D_;

  ElementAccumulator alpha_;
  ElementAccumulator beta_;

  ElementNorm *ptr_Max_;
  ElementSum *ptr_Sum_;

  int column_offset_;

  ElementSoftmaxCompute accum_max_;
  ElementSoftmaxCompute accum_sum_;

  MatrixCoord thread_offset_;

  float infinity_;

public:

  CUTLASS_DEVICE
  EpilogueVisitorSoftmax(
    Params const &params,
    SharedStorage &shared_storage,
    cutlass::MatrixCoord const &problem_size,
    int thread_idx,
    int warp_idx,
    int lane_idx,
    typename OutputTileIterator::Params params_C,
    typename OutputTileIterator::Params params_D,
    typename OutputTileIterator::Element *ptr_C,
    typename OutputTileIterator::Element *ptr_D,
    ElementNorm *ptr_Max = nullptr,
    ElementSum *ptr_Sum = nullptr,
    cutlass::MatrixCoord const &threadblock_offset = cutlass::MatrixCoord(0, 0),
    int column_offset = 0,
    cutlass::MatrixCoord const &problem_size_real = cutlass::MatrixCoord(0, 0),
    float infinity = 10000.0f
  ):
    params_(params),
    shared_storage_(shared_storage),
    extent_(problem_size),
    elementwise_(params.elementwise),
    iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
    iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
    ptr_Max_(ptr_Max),
    ptr_Sum_(ptr_Sum),
    column_offset_(column_offset),
    extent_real_(problem_size_real),
    infinity_(infinity)
  {
    alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha);
    beta_  = (params.elementwise.beta_ptr  ? *params.elementwise.beta_ptr  : params.elementwise.beta);

    if (beta_ == ElementAccumulator()) {
      iterator_C_.clear_mask();
    }
  }

  /// Helper to indicate split-K behavior
  CUTLASS_DEVICE
  void set_k_partition(
    int split_k_index,     ///< Index of this threadblock within split-K partitioned scheme
    int split_k_slices) {  ///< Total number of split-K slices
  }

  /// Called to set the batch index
  CUTLASS_DEVICE
  void set_batch_index(int batch_idx) {
    iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
    iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
  }

  /// Called at the start of the epilogue just before iterating over accumulator slices
  CUTLASS_DEVICE
  void begin_epilogue() { }

  /// Called at the start of one step before starting accumulator exchange
  CUTLASS_DEVICE
  void begin_step(int step_idx) {
    fragment_D_.clear();
    fragment_C_.clear();

    if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
      iterator_C_.load(fragment_C_);
      ++iterator_C_;
    }
  }

  /// Called at the start of a row
  CUTLASS_DEVICE
  void begin_row(int row_idx) {
    // Clear accumulators for max and sum when starting a whole row
    clear_accum_();
  }

  /// Called after accumulators have been exchanged for each accumulator vector
  CUTLASS_DEVICE
  void visit(
      int iter_idx,
      int row_idx,
      int column_idx,
      int frag_idx,
      AccumulatorFragment const &accum) {

    using Mul = cutlass::multiplies<SoftmaxFragment>;
    using Minus = cutlass::minus<SoftmaxFragment>;
    using Exp = cutlass::fast_exp_op<SoftmaxFragment>;

    Minus minus;
    Exp exponential;

    SoftmaxFragment result;

    NumericArrayConverter<ElementSoftmaxCompute, ElementOutput, kElementsPerAccess> source_converter;
    OutputVector &source_vector = reinterpret_cast<OutputVector *>(&fragment_C_)[frag_idx];

    if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
      result = source_converter(elementwise_(accum));
    } else {
      result = source_converter(elementwise_(accum, source_vector));
    }

    thread_offset_ =
        iterator_D_.thread_start() +
        OutputTileIterator::ThreadMap::iteration_offset(frag_idx);

    bool column_guard = (thread_offset_.column() < extent_.column());

    if (kUseMasking) {
      int elements_in_boundary = extent_real_.column() - thread_offset_.column();
      elements_in_boundary = (elements_in_boundary > kElementsPerAccess) ? kElementsPerAccess : elements_in_boundary;
      elementwise_padding_(result, elements_in_boundary);
    }

    ElementSoftmaxCompute accum_max_prev = accum_max_;

    // Compute the maximum within one row
    if (!column_idx) {
      // This is the first fragment in a new row
      if (column_guard) {
        accum_max_ = maximum_accumulator_(result);
      }
    }
    else {
      // This is an additional fragment in the same row
      if (column_guard) {
        accum_max_ = maximum_accumulator_(result, accum_max_);
      }
    }

    // Proactively compute the row maximum across the warp
    accum_max_ = warp_reduce_max_(accum_max_);

    ElementSoftmaxCompute updater = fast_exp(accum_max_prev - accum_max_);

    SoftmaxFragment intermediate = exponential(minus(result, accum_max_));

    if (kHasMultiStepsInRow) {
      if (!column_idx) {
        accum_sum_ = (column_guard) ? sum_accumulator_(intermediate) : ElementSoftmaxCompute(0);
      } else {
        // Algorithm from Sec. 3.1 of https://arxiv.org/pdf/2205.14135v1.pdf:
        // S* = S* x updater + sum_row(P'), where updater = exp(M* - M_row)
        accum_sum_ = (column_guard) ? sum_accumulator_(intermediate, accum_sum_ * updater)
                                    : accum_sum_ * updater;
      }
    } else {
      accum_sum_ = (column_guard) ? sum_accumulator_(intermediate, accum_sum_) : ElementSoftmaxCompute(0);
    }

    // Convert to the output type
    NumericArrayConverter<ElementOutput, ElementSoftmaxCompute, kElementsPerAccess> output_converter;
    OutputVector &output = reinterpret_cast<OutputVector *>(&fragment_D_)[frag_idx];
    output = output_converter(result);
  }

  /// Called at the end of a row
  CUTLASS_DEVICE
  void end_row(int row_idx) {

    using ConvertSumOutput = cutlass::NumericConverter<ElementSum, ElementSoftmaxCompute>;
    using ConvertNormOutput = cutlass::NumericConverter<ElementNorm, ElementSoftmaxCompute>;

    ConvertSumOutput convert_sum_output;
    ConvertNormOutput convert_norm_output;

    // Compute the accumulated sum only in the last step
    accum_sum_ = warp_reduce_sum_(accum_sum_);

    bool is_first_thread_in_tile = ((threadIdx.x % kThreadsPerRow) == 0);
    bool row_guard = thread_offset_.row() < extent_.row();
    bool is_write_thread = row_guard && is_first_thread_in_tile;

    int block_batch = blockIdx.z;

    ElementNorm *curr_ptr_max = ptr_Max_ + thread_offset_.row() + column_offset_ + block_batch * params_.batch_stride_Max;
    ElementSum *curr_ptr_sum = ptr_Sum_ + thread_offset_.row() + column_offset_ + block_batch * params_.batch_stride_Sum;

    arch::global_store<ElementNorm, sizeof(ElementNorm)>(
        convert_norm_output(accum_max_),
        (void *)curr_ptr_max,
        is_write_thread);

    arch::global_store<ElementSum, sizeof(ElementSum)>(
        convert_sum_output(accum_sum_),
        (void *)curr_ptr_sum,
        is_write_thread);

    // Clear accumulators for max and sum when finishing a whole row
    clear_accum_();
  }

  /// Called after all accumulator elements have been visited
  CUTLASS_DEVICE
  void end_step(int step_idx) {
    iterator_D_.store(fragment_D_);
    ++iterator_D_;
  }

  /// Called after all steps have been completed
  CUTLASS_DEVICE
  void end_epilogue() { }

private:

  CUTLASS_DEVICE
  void elementwise_padding_(SoftmaxFragment &result, int elements_in_boundary) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < SoftmaxFragment::kElements; ++i) {
      result[i] = (i < elements_in_boundary) ? result[i] : ElementSoftmaxCompute(-infinity_);
    }
  }

  CUTLASS_DEVICE
  ElementSoftmaxCompute warp_reduce_sum_(ElementSoftmaxCompute sum_) {
    int half_thread_in_row = (kThreadsPerRow >> 1);
    CUTLASS_PRAGMA_UNROLL
    for (int i = half_thread_in_row; i > 0; i >>= 1) {
      ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, sum_, i);
      sum_ += tmp;
    }
    return sum_;
  }

  CUTLASS_DEVICE
  ElementSoftmaxCompute warp_reduce_max_(ElementSoftmaxCompute max_) {
    int half_thread_in_row = (kThreadsPerRow >> 1);
    CUTLASS_PRAGMA_UNROLL
    for (int i = half_thread_in_row; i > 0; i >>= 1) {
      ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, max_, i);
      max_ = fast_max(max_, tmp);
    }
    return max_;
  }

  CUTLASS_DEVICE
  void clear_accum_() {
    uint32_t float_max_bits = 0xff7fffff;  // -FLT_MAX
    float min_float = reinterpret_cast<float const &>(float_max_bits);
    accum_max_ = ElementSoftmaxCompute(min_float);
    accum_sum_ = ElementSoftmaxCompute(0);
  }

  CUTLASS_DEVICE
  ElementSoftmaxCompute sum_accumulator_(SoftmaxFragment const &accum) {
    ElementSoftmaxCompute sum_ = ElementSoftmaxCompute(0);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < SoftmaxFragment::kElements; ++i) {
      sum_ += ElementSoftmaxCompute(accum[i]);
    }

    return sum_;
  }

  CUTLASS_DEVICE
  ElementSoftmaxCompute sum_accumulator_(SoftmaxFragment const &accum, ElementSoftmaxCompute sum_) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < SoftmaxFragment::kElements; ++i) {
      sum_ += ElementSoftmaxCompute(accum[i]);
    }

    return sum_;
  }

  CUTLASS_DEVICE
  ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum) {
    ElementSoftmaxCompute max_ = accum[0];

    CUTLASS_PRAGMA_UNROLL
    for (int i = 1; i < SoftmaxFragment::kElements; ++i) {
      max_ = fast_max(max_, ElementSoftmaxCompute(accum[i]));
    }

    return max_;
  }

  CUTLASS_DEVICE
  ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum, ElementSoftmaxCompute max_) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < SoftmaxFragment::kElements; ++i) {
      max_ = fast_max(max_, ElementSoftmaxCompute(accum[i]));
    }

    return max_;
  }
};

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
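
visit() maintains a running row maximum and a rescaled running sum in the online-softmax style of Sec. 3.1 of https://arxiv.org/pdf/2205.14135v1.pdf: whenever a new fragment raises the row max from M* to M_row, the previous partial sum is multiplied by updater = exp(M* - M_row) before the new exponentials are added. A scalar sketch of the same update over a row processed in chunks (plain C++, not the visitor code):

#include <algorithm>
#include <cmath>
#include <vector>

// Online softmax statistics over a row visited chunk by chunk, mirroring the
// update in visit(): running max m and running sum s of exp(x - m).
void online_softmax_stats(std::vector<float> const &row, int chunk,
                          float &m, float &s) {
  m = -INFINITY;
  s = 0.0f;
  for (size_t start = 0; start < row.size(); start += chunk) {
    size_t end = std::min(row.size(), start + chunk);
    float m_old = m;
    for (size_t i = start; i < end; ++i) m = std::max(m, row[i]);
    float updater = std::exp(m_old - m);  // rescales the old partial sum
    float partial = 0.0f;
    for (size_t i = start; i < end; ++i) partial += std::exp(row[i] - m);
    s = s * updater + partial;            // S* = S* x updater + sum_row(P')
  }
  // softmax(x_i) = exp(x_i - m) / s can then be formed in a second pass.
}
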
@ -39,11 +39,12 @@

#pragma once

-#include <utility>
+#if defined(__CUDACC_RTC__)
+#include <cuda/std/cassert>
+#include <cuda/std/utility>
+#else
+#include <assert.h>
+#include <utility>
+#endif

#include "cutlass/cutlass.h"
@ -121,6 +121,7 @@ public:
  /// Called after accumulators have been exchanged for each accumulator vector
  CUTLASS_DEVICE
  void visit(
    int iter_idx,
+   int row_idx,
    int column_idx,
    int frag_idx,
@ -128,7 +129,7 @@ public:

  }

- /// Called at the start of a row
+ /// Called at the end of a row
  CUTLASS_DEVICE
  void end_row(int row_idx) {

@ -325,6 +326,7 @@ public:
  }

  visitor.visit(
    iter_idx,
+   row_idx,
    col_idx,
    idx,
@ -391,10 +391,10 @@ struct OutputTileOptimalThreadMap {
    1>;

  /// Initial offset function
- CUTLASS_HOST_DEVICE
+ CUTLASS_DEVICE
  static MatrixCoord initial_offset(int thread_idx) {

-   int warp_idx = thread_idx / kWarpSize;
+   int warp_idx = __shfl_sync(0xffffffff, thread_idx / kWarpSize, 0);
    int lane_idx = thread_idx % kWarpSize;

    // Compute warp location
@ -419,7 +419,7 @@ struct OutputTileOptimalThreadMap {

    return MatrixCoord(
      cluster_offset + group_offset + row_offset + lane_row_offset,
-     (column_offset + lane_col_offset) * kElementsPerAccess
+     column_offset + lane_col_offset * kElementsPerAccess
    );
  }

@ -461,10 +461,10 @@ struct OutputTileOptimalThreadMap {
  static int const kThreads = Threads;

  /// Function to compute each thread's initial offset
- CUTLASS_HOST_DEVICE
+ CUTLASS_DEVICE
  static MatrixCoord initial_offset(int thread_idx) {

-   int warp_idx = thread_idx / kWarpSize;
+   int warp_idx = __shfl_sync(0xffffffff, thread_idx / kWarpSize, 0);
    int lane_idx = thread_idx % kWarpSize;

    // Compute warp location
@ -489,7 +489,7 @@ struct OutputTileOptimalThreadMap {

    MatrixCoord coord(
      cluster_offset + group_offset + row_offset + lane_row_offset,
-     (column_offset + lane_col_offset) * kElementsPerAccess
+     column_offset + lane_col_offset * kElementsPerAccess
    );

    return coord;
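
In both hunks above, warp_idx is mathematically identical on every lane, but the compiler cannot prove that from thread_idx / kWarpSize alone; broadcasting lane 0's copy with __shfl_sync makes the value provably warp-uniform, which (an inferred motivation, not stated in the diff) lets the compiler keep it and values derived from it in uniform registers. A standalone CUDA sketch of the idiom:

__global__ void uniform_broadcast_demo(int *out) {
  int thread_idx = threadIdx.x;
  // Every lane computes the same warp index, but the compiler cannot know that.
  int warp_idx_raw = thread_idx / 32;
  // After the broadcast, all lanes provably share lane 0's copy of the value.
  int warp_idx = __shfl_sync(0xffffffff, warp_idx_raw, 0);
  out[thread_idx] = warp_idx;
}
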
@ -43,6 +43,7 @@
#include "cutlass/array.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"
+#include "cutlass/layout/permute.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
@ -70,6 +71,7 @@ template <
  typename ThreadMap_,                          ///< Thread map (concept: OutputTileThreadMap)
  typename Element_,                            ///< Element data type
  bool ScatterD = false,                        ///< Scatter D operand or not
+ typename PermuteDLayout = layout::NoPermute,  ///< Permute D operand or not
  bool UseCUDAStore = false
>
class PredicatedTileIterator {
@ -173,9 +175,12 @@ private:
  /// Parameters structure containing reference and precomputed state.
  PredicatedTileIteratorParams params_;

- /// Byte-level pointer
+ /// Byte-level pointer. This pointer is usually used for both load() and store(), unless
+ /// PermuteD is performed. When PermuteD is in use, byte_pointer_ is used only for load().
  uint8_t *byte_pointer_;

+ /// Byte-level pointer for store(). Due to the PermuteD op, store_byte_pointer_ may require
+ /// a different address computation than byte_pointer_.
+ uint8_t *store_byte_pointer_;
+
  /// Array of boolean values to contain steady-state predicates
  Mask mask_;

@ -196,6 +201,11 @@ private:

  /// Scatter indices
  int const *indices_;

+ /// Whether to perform the Permute op
+ bool PermuteD;
+
+ /// PermuteD layout
+ mutable PermuteDLayout permute_layout_;
+
  //
  // Static asserts about internal strides
@ -255,7 +265,7 @@ public:
      mask_.clear();
    }

-   // Initialize pointer
+   // Initialize byte_pointer_
    byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
      LongIndex(thread_offset.row()) * LongIndex(params_.stride) +
      LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess;
@ -265,6 +275,19 @@ public:
        LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess;
    }

+   // store_byte_pointer_ is the same as byte_pointer_ unless PermuteD is used.
+   store_byte_pointer_ = byte_pointer_;
+
+   // Initialize PermuteD. If PermuteD is true, store_byte_pointer_ is initialized accordingly.
+   if (platform::is_same<PermuteDLayout, layout::NoPermute>::value) {
+     PermuteD = false;
+   } else {
+     PermuteD = true;
+     store_byte_pointer_ = reinterpret_cast<uint8_t *>(pointer);
+     permute_layout_ = PermuteDLayout(extent,
+                                      params_.stride * kElementsPerAccess / sizeof(AccessType));
+   }
+
    // Initialize internal state counter
    state_[0] = state_[1] = state_[2] = 0;
  }
@ -272,6 +295,7 @@ public:
  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
+   store_byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
    byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
  }

@ -353,7 +377,7 @@ public:
  /// Stores a fragment to memory
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) const {
-   uint8_t *byte_pointer = byte_pointer_;
+   uint8_t *byte_pointer = store_byte_pointer_;
    AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);

    CUTLASS_PRAGMA_UNROLL
@ -388,21 +412,38 @@ public:

        bool guard = row_guard && mask_.predicates[column];

+       int col_offset = column * ThreadMap::Delta::kColumn;
+
+       if (PermuteD) {
+         int col = col_offset + thread_start_column_;
+         int row = row_offset + thread_start_row_;
+
+         TensorCoord init_coord(row, col);
+
+         // Locate memory_pointer via the permuted coordinate
+         memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset
+           + permute_layout_(init_coord) * sizeof(AccessType) / kElementsPerAccess);
+       }
+
        if (UseCUDAStore) {
          if (guard) {
-           memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] =
+           memory_pointer[0] =
              frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column];
          }
        } else {
          cutlass::arch::global_store<AccessType, sizeof(AccessType)>(
            frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column],
-           (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess],
+           (void *)&memory_pointer[0],
            guard);
        }
+
+       if (!PermuteD) {
+         memory_pointer += (ThreadMap::Delta::kColumn / kElementsPerAccess);
+       }
      }

      if (row + 1 < ThreadMap::Iterations::kRow) {
-       if (!ScatterD) {
+       if (!ScatterD && !PermuteD) {
          byte_pointer += params_.increment_row;
        }
      }
@ -605,6 +646,10 @@ public:

    ++state_[0];

+   if (!ScatterD && !PermuteD) {
+     store_byte_pointer_ += params_.advance_row;
+   }
+
    if (!ScatterD) {
      byte_pointer_ += params_.advance_row;
    }
@ -616,6 +661,7 @@ public:
      state_[0] = 0;
      ++state_[1];
      byte_pointer_ += params_.advance_group;
+     store_byte_pointer_ += params_.advance_group;

      thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
        ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
@ -625,6 +671,7 @@ public:
        state_[1] = 0;
        ++state_[2];
        byte_pointer_ += params_.advance_cluster;
+       store_byte_pointer_ += params_.advance_cluster;

        thread_start_row_ += ThreadMap::Count::kGroup *
          ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
@ -632,6 +679,7 @@ public:
      if (state_[2] == ThreadMap::Count::kCluster) {
        state_[2] = 0;
        byte_pointer_ += params_.advance_tile;
+       store_byte_pointer_ += params_.advance_tile;
      }
    }
  }
@ -35,6 +35,8 @@
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/matrix.h"

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -908,3 +908,4 @@ T absolute_value(T x) {
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
+
478
include/cutlass/gemm/device/base_grouped.h
Normal file
@ -0,0 +1,478 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*!
|
||||
\file
|
||||
\brief Base device-level grouped kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <limits>
|
||||
#include <numeric>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.h"
|
||||
|
||||
#include "cutlass/gemm/kernel/default_gemm_universal.h"
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace device {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GEMM Grouped
|
||||
template <typename BaseKernel_>
|
||||
class BaseGrouped {
|
||||
public:
|
||||
|
||||
using BaseKernel = BaseKernel_;
|
||||
|
||||
using ElementA = typename BaseKernel::ElementA;
|
||||
using LayoutA = typename BaseKernel::LayoutA;
|
||||
using TensorRefA = TensorRef<ElementA const, LayoutA>;
|
||||
static ComplexTransform const kTransformA = BaseKernel::kTransformA;
|
||||
static int const kAlignmentA = BaseKernel::kAlignmentA;
|
||||
|
||||
using ElementB = typename BaseKernel::ElementB;
|
||||
using LayoutB = typename BaseKernel::LayoutB;
|
||||
using TensorRefB = TensorRef<ElementB const, LayoutB>;
|
||||
static ComplexTransform const kTransformB = BaseKernel::kTransformB;
|
||||
static int const kAlignmentB = BaseKernel::kAlignmentB;
|
||||
|
||||
using ElementC = typename BaseKernel::ElementC;
|
||||
using LayoutC = typename BaseKernel::LayoutC;
|
||||
using TensorRefC = TensorRef<ElementC const, LayoutC>;
|
||||
using TensorRefD = TensorRef<ElementC, LayoutC>;
|
||||
static int const kAlignmentC = BaseKernel::kAlignmentC;
|
||||
|
||||
using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC;
|
||||
|
||||
using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp;
|
||||
using ThreadblockSwizzle = typename BaseKernel::ThreadblockSwizzle;
|
||||
|
||||
using Operator = typename BaseKernel::Operator;
|
||||
using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator;
|
||||
|
||||
using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
|
||||
using MathOperator = typename WarpMmaOperator::MathOperator;
|
||||
using OperatorClass = typename WarpMmaOperator::OperatorClass;
|
||||
using ArchTag = typename WarpMmaOperator::ArchTag;
|
||||
using ThreadblockShape = typename BaseKernel::Mma::Shape;
|
||||
using WarpShape = typename BaseKernel::WarpShape;
|
||||
using InstructionShape = typename BaseKernel::InstructionShape;
|
||||
static int const kStages = BaseKernel::Mma::kStages;
|
||||
|
||||
/// Argument structure
|
||||
using Arguments = typename BaseKernel::Arguments;
|
||||
|
||||
using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo;
|
||||
|
||||
protected:
|
||||
|
||||
/// Kernel parameters object
|
||||
typename BaseKernel::Params params_;
|
||||
|
||||
private:
|
||||
|
||||
/// Get the number of tiles across all problems in a group
|
||||
static int32_t group_tile_count(const cutlass::gemm::GemmCoord* problem_sizes_ptr, int problem_count) {
|
||||
int32_t tiles = 0;
|
||||
for (int32_t i = 0; i < problem_count; ++i) {
|
||||
cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i];
|
||||
BaseKernel::ProblemVisitor::possibly_transpose_problem(problem);
|
||||
tiles += problem_tile_count(problem);
|
||||
}
|
||||
return tiles;
|
||||
}
|
||||
|
||||
/// Copy from `data` to `workspace`
|
||||
Status copy_to_workspace(void* workspace, void* data, size_t bytes) {
|
||||
cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice);
|
||||
if (cuda_error != cudaSuccess) {
|
||||
// Call cudaGetLastError() to clear the error bit
|
||||
cuda_error = cudaGetLastError();
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaMemcpy() returned error "
|
||||
<< cudaGetErrorString(cuda_error));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Precomputes scheduling information for the grouped GEMM
|
||||
Status precompute(Arguments const &args, int32_t tile_count, void* workspace) {
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
std::vector<uint8_t> host_workspace(workspace_bytes);
|
||||
BaseKernel::ProblemVisitor::host_precompute(args.host_problem_sizes,
|
||||
args.problem_count,
|
||||
args.threadblock_count,
|
||||
(void*)host_workspace.data());
|
||||
return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes);
|
||||
}
|
||||
|
||||
/// Reorder `data` according to `indices`
|
||||
template <typename T>
|
||||
static void reorder_array(T* data, const std::vector<size_t>& indices) {
|
||||
// For now, simply create a copy of the data and then copy over to the original.
|
||||
std::vector<T> copy(indices.size());
|
||||
for (int i = 0; i < indices.size(); ++i) {
|
||||
copy.at(i) = data[indices[i]];
|
||||
}
|
||||
|
||||
memcpy(data, copy.data(), indices.size() * sizeof(T));
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs the GEMM.
|
||||
BaseGrouped() { }
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status can_implement(Arguments const &args) {
|
||||
|
||||
return BaseKernel::can_implement(args);
|
||||
}
|
||||
|
||||
  /// Get the number of tiles in a problem
  static int32_t problem_tile_count(cutlass::gemm::GemmCoord const &problem) {
    auto grid = BaseKernel::ProblemVisitor::grid_shape(problem);
    return BaseKernel::ProblemVisitor::tile_count(grid);
  }
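  // Worked example (illustrative, assuming a 128x128 threadblock tile and a
  // problem visitor that tiles only over M and N): a 1024x512x256 problem
  // yields a grid of ceil(1024/128) x ceil(512/128) = 8 x 4, i.e. 32 tiles.
  // The exact count is whatever ProblemVisitor::grid_shape() and tile_count()
  // compute for the kernel's configuration.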
  /// Get the number of tiles across all problems in a group
  static int32_t group_tile_count(Arguments const &args) {
    if (args.host_problem_sizes == nullptr) {
      CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes`");
      return -1;
    }

    return group_tile_count(args.host_problem_sizes, args.problem_count);
  }
  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {
    if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) {
      return BaseKernel::ProblemVisitor::get_workspace_size(args.host_problem_sizes,
                                                            args.problem_count,
                                                            args.threadblock_count);
    } else {
      return 0;
    }
  }
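  // Usage sketch (illustrative, not part of the original header): callers are
  // expected to size and allocate the device workspace themselves before
  // calling initialize(). Assuming `Gemm` is an instantiation of this template
  // and `args` is a populated Arguments object:
  //
  //   size_t workspace_bytes = Gemm::get_workspace_size(args);
  //   void *workspace = nullptr;
  //   if (workspace_bytes) {
  //     cudaMalloc(&workspace, workspace_bytes);  // check the cudaError_t in real code
  //   }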
  /// Computes the grid shape
  static dim3 get_grid_shape(Arguments const &args) {

    return dim3(args.threadblock_count, 1, 1);
  }
  /// Computes the maximum number of active blocks per multiprocessor
  static int maximum_active_blocks(int smem_capacity = -1) {

    CUTLASS_TRACE_HOST("BaseGrouped::maximum_active_blocks()");

    int smem_size = int(sizeof(typename BaseKernel::SharedStorage));

    CUTLASS_TRACE_HOST("  smem_size: " << smem_size << " bytes");

    cudaError_t result;
    if (smem_size > (48 << 10)) {
      result = cudaFuncSetAttribute(Kernel<BaseKernel>,
                                    cudaFuncAttributeMaxDynamicSharedMemorySize,
                                    smem_size);

      if (result != cudaSuccess) {
        // Call cudaGetLastError() to clear the error bit
        result = cudaGetLastError();
        CUTLASS_TRACE_HOST(
          "  cudaFuncSetAttribute() returned error "
          << cudaGetErrorString(result));
        return -1;
      }
    }

    int max_active_blocks = -1;
    result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_active_blocks,
        Kernel<BaseKernel>,
        BaseKernel::kThreadCount,
        smem_size);

    if (result != cudaSuccess) {
      // Call cudaGetLastError() to clear the error bit
      result = cudaGetLastError();
      CUTLASS_TRACE_HOST(
        "  cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
        << cudaGetErrorString(result));
      return -1;
    }

    CUTLASS_TRACE_HOST("  max_active_blocks: " << max_active_blocks);
    return max_active_blocks;
  }
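  // Illustrative arithmetic (not part of the original header): sufficient()
  // below multiplies this occupancy figure by the SM count to size one full
  // wave. For example, on a hypothetical GPU with 108 SMs and a kernel that
  // achieves 2 active blocks per SM, a full wave is 108 * 2 = 216 threadblocks.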
  /// Sorts each pointer passed in according to the indices that sort
  /// `problem_sizes_ptr` in descending order of problem-K dimension.
  static void sort_problems(int problem_count,
                            cutlass::gemm::GemmCoord* problem_sizes_ptr,
                            int64_t* lda_host_ptr,
                            int64_t* ldb_host_ptr,
                            int64_t* ldc_host_ptr,
                            int64_t* ldd_host_ptr,
                            int64_t* offset_A_ptr,
                            int64_t* offset_B_ptr,
                            int64_t* offset_C_ptr,
                            int64_t* offset_D_ptr)
  {
    std::vector<size_t> indices(problem_count);
    std::iota(indices.begin(), indices.end(), 0);
    std::stable_sort(indices.begin(), indices.end(),
      [&problem_sizes_ptr](size_t i, size_t j) {
        return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k();
      });

    reorder_array(problem_sizes_ptr, indices);
    reorder_array(lda_host_ptr, indices);
    reorder_array(ldb_host_ptr, indices);
    reorder_array(ldc_host_ptr, indices);
    reorder_array(ldd_host_ptr, indices);
    reorder_array(offset_A_ptr, indices);
    reorder_array(offset_B_ptr, indices);
    reorder_array(offset_C_ptr, indices);
    reorder_array(offset_D_ptr, indices);
  }
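  // Usage sketch (illustrative; the host-side vectors below are assumptions,
  // not part of this header). Sorting by descending K tends to schedule the
  // longest-running tiles first, which improves load balance across
  // threadblocks:
  //
  //   std::vector<cutlass::gemm::GemmCoord> problem_sizes = /* one per problem */;
  //   std::vector<int64_t> lda, ldb, ldc, ldd, offset_A, offset_B, offset_C, offset_D;
  //   // ... fill one entry per problem ...
  //   Gemm::sort_problems(int(problem_sizes.size()),
  //                       problem_sizes.data(),
  //                       lda.data(), ldb.data(), ldc.data(), ldd.data(),
  //                       offset_A.data(), offset_B.data(),
  //                       offset_C.data(), offset_D.data());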
  /// Computes the number of threadblocks to launch for the grouped kernel
  static int sufficient(const cutlass::gemm::GemmCoord* problem_sizes_ptr = nullptr,
                        int problem_count = 0,
                        int available_sm_count = -1) {
    // Determine the number of blocks that would be launched to fill up a single
    // wave on the GPU with each SM having maximum occupancy.
    cudaDeviceProp properties;
    int device_idx;
    cudaError_t result = cudaGetDevice(&device_idx);
    if (result != cudaSuccess) {
      // Call cudaGetLastError() to clear the error bit
      result = cudaGetLastError();
      CUTLASS_TRACE_HOST("  cudaGetDevice() returned error "
                          << cudaGetErrorString(result));
      return 0;
    }

    result = cudaGetDeviceProperties(&properties, device_idx);
    if (result != cudaSuccess) {
      // Call cudaGetLastError() to clear the error bit
      result = cudaGetLastError();
      CUTLASS_TRACE_HOST("  cudaGetDeviceProperties() returned error "
                          << cudaGetErrorString(result));
      return 0;
    }

    bool override_sm_count = (available_sm_count < 0 || available_sm_count > properties.multiProcessorCount);
    if (override_sm_count) {
      available_sm_count = properties.multiProcessorCount;
    }

    int max_active_blocks = maximum_active_blocks();
    if (max_active_blocks <= 0) {
      return 0;
    }

    int occupancy_based_block_count = available_sm_count * max_active_blocks;

    if (problem_sizes_ptr == nullptr || problem_count == 0) {
      return occupancy_based_block_count;
    }

    int total_tiles = group_tile_count(problem_sizes_ptr, problem_count);

    // If the group contains a single problem, launching the exact number of
    // threadblocks needed to cover the problem minimizes the work performed
    // per threadblock in finding the next tile to compute. We return total_tiles
    // unless the user has provided the SM count.
    if (problem_count == 1 && override_sm_count) {
      return total_tiles;
    }

    // Choose between the full wave of threadblocks and the tile count. If there
    // are fewer tiles in the group than threadblocks in the full wave, only
    // some threadblocks will be assigned tiles. Those threadblocks which are
    // not assigned tiles still need to perform the work of iterating through
    // problem sizes to determine that they have no work to do. This competes
    // for cycles with those threadblocks that are assigned tiles to compute.
    return std::min(total_tiles, occupancy_based_block_count);
  }
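  // Usage sketch (illustrative): pick the launch size before building the
  // Arguments object. `Gemm` and the host problem-size vector are assumptions:
  //
  //   int threadblock_count = Gemm::sufficient(problem_sizes.data(),
  //                                            int(problem_sizes.size()));
  //   if (!threadblock_count) {
  //     // Zero indicates the device could not be queried or the kernel
  //     // cannot achieve occupancy; fall back or report an error.
  //   }
  //   // threadblock_count is then passed into typename Gemm::Arguments(...).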
  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

    CUTLASS_TRACE_HOST("BaseGrouped::initialize() - workspace "
                       << workspace << ", stream: " << (stream ? "non-null" : "null"));

    // Workspace
    size_t workspace_bytes = get_workspace_size(args);

    if (workspace_bytes && !workspace) {
      return Status::kErrorWorkspaceNull;
    }

    if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) {
      int32_t tile_count = group_tile_count(args);
      Status status = precompute(args, tile_count, workspace);
      if (status != Status::kSuccess) {
        return status;
      }

      params_ = typename BaseKernel::Params(args, workspace, tile_count);
    } else {
      params_ = typename BaseKernel::Params(args, workspace);
    }

    // Specify shared memory capacity for kernel.
    int smem_size = int(sizeof(typename BaseKernel::SharedStorage));

    if (smem_size >= (48 << 10)) {
      cudaError_t result = cudaFuncSetAttribute(Kernel<BaseKernel>,
                                                cudaFuncAttributeMaxDynamicSharedMemorySize,
                                                smem_size);

      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }
    }

    return Status::kSuccess;
  }
  /// Lightweight update given a subset of arguments
  Status update(Arguments const &args, void *workspace = nullptr) {

    size_t workspace_bytes = get_workspace_size(args);

    if (workspace_bytes && !workspace) {
      return Status::kErrorWorkspaceNull;
    }

    if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) {
      int32_t tile_count = group_tile_count(args);
      Status status = precompute(args, tile_count, workspace);
      if (status != Status::kSuccess) {
        return status;
      }

      params_.update(args, workspace, tile_count);
    } else {
      params_.update(args, workspace);
    }

    return Status::kSuccess;
  }
  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {

    //
    // Configure grid and block dimensions
    //

    if (!params_.problem_visitor.problem_count) {
      return Status::kSuccess;
    }

    dim3 grid(params_.threadblock_count, 1, 1);
    dim3 block(BaseKernel::kThreadCount, 1, 1);

    int smem_size = int(sizeof(typename BaseKernel::SharedStorage));

    //
    // Launch kernel
    //

    cutlass::Kernel<BaseKernel><<<grid, block, smem_size, stream>>>(params_);

    //
    // Query for errors
    //
    // cudaGetLastError() both returns and clears the sticky error bit, so a
    // single call suffices here.
    cudaError_t result = cudaGetLastError();

    if (result != cudaSuccess) {
      CUTLASS_TRACE_HOST("  grid launch failed with error " << cudaGetErrorString(result));
      return Status::kErrorInternal;
    }

    return Status::kSuccess;
  }
  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Initializes and runs the kernel.
  Status operator()(
    Arguments const &args,
    void *workspace,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace, stream);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
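// End-to-end usage sketch (illustrative, not part of this header; the exact
// Arguments constructor parameters are elided because they depend on the
// kernel instantiation). Assumes a grouped kernel type `GemmKernel` and host
// problem sizes as in the grouped GEMM example (examples/24_gemm_grouped):
//
//   using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>;
//
//   int threadblock_count = Gemm::sufficient(problem_sizes_host.data(),
//                                            int(problem_sizes_host.size()));
//   typename Gemm::Arguments args(/* device-side problem sizes, pointers,
//                                    leading dimensions, epilogue params, */
//                                 threadblock_count
//                                 /* , host problem sizes, ... */);
//
//   Gemm gemm_op;
//   void *workspace = nullptr;
//   size_t workspace_bytes = Gemm::get_workspace_size(args);
//   if (workspace_bytes) {
//     cudaMalloc(&workspace, workspace_bytes);  // check the cudaError_t in real code
//   }
//
//   cutlass::Status status = gemm_op.initialize(args, workspace);
//   if (status == cutlass::Status::kSuccess) {
//     status = gemm_op.run();
//   }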