CUTLASS 2.10 (#615)

Co-authored-by: Aniket Shivam <ashivam@nvidia.com>
ANIKET SHIVAM
2022-09-03 15:48:46 -07:00
committed by GitHub
parent ca23ff7924
commit b72cbf957d
289 changed files with 43708 additions and 2513 deletions

View File

@ -20,4 +20,4 @@ A clear and concise description of what you expected to happen.
- Environment location: [Bare-metal, Docker, Cloud(specify cloud provider)]
**Additional context**
Add any other context about the problem here.

View File

@ -32,4 +32,4 @@ A clear and concise description of what documentation you believe it is needed a
A clear and concise description of what you want to happen.
**Steps taken to search for needed documentation**
List any steps you have taken:

View File

@ -7,4 +7,4 @@ assignees: ''
---
**What is your question?**

View File

@ -8,4 +8,4 @@ jobs:
steps:
- uses: actions/labeler@main
with:
repo-token: "${{ secrets.GITHUB_TOKEN }}"

View File

@ -32,4 +32,4 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GITHUB_PROJECT_URL: https://github.com/NVIDIA/cutlass
GITHUB_PROJECT_COLUMN_NAME: 'Needs prioritizing'

View File

@ -54,4 +54,4 @@ jobs:
exempt-pr-labels: "0 - Blocked,0 - Backlog,good first issue"
days-before-pr-stale: 90
days-before-pr-close: -1
operations-per-run: 50

View File

@ -1,5 +1,18 @@
# NVIDIA CUTLASS Changelog
## [2.10.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.10.0) (2022-08-23)
* [Grouped convolution targeting implicit GEMM](test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu)
* [Depthwise separable convolution](test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu)
* Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel
* [Grouped GEMM for Multihead Attention](examples/50_multi_head_attention)
* [GEMM + Layer norm fusion for Ampere](examples/37_gemm_layernorm_gemm_fusion/)
* Updates and bugfixes from the community (thanks!)
* **Deprecation announcement:** CUTLASS plans to deprecate the following:
* Maxwell and Pascal GPU architectures
* Ubuntu 16.04
* CUDA 10.2
## [2.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.9.0) (2022-04-21)
* [First layer Convolution kernels](/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment
@ -37,6 +50,7 @@
* Optimal performance using [**CUDA 11.7**](https://developer.nvidia.com/cuda-downloads)
* Updates and bugfixes from the community (thanks!)
## [2.8.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.8.0) (2021-11-19)
* **TF32x3:** emulated single-precision using Tensor Cores

View File

@ -1,82 +0,0 @@
cff-version: 1.2.0
title: CUTLASS
message: >-
If you use this software, please cite using the
following metadata.
type: software
authors:
- given-names: Andrew
email: akerr@nvidia.com
family-names: Kerr
affiliation: NVIDIA
- given-names: Haicheng
family-names: Wu
affiliation: NVIDIA
email: haichengw@nvidia.com
- given-names: Manish
family-names: Gupta
affiliation: Google
email: manigupta@google.com
- given-names: Dustyn
family-names: Blasig
email: dblasig@nvidia.com
affiliation: NVIDIA
- given-names: Pradeep
family-names: Ramani
email: prramani@nvidia.com
affiliation: NVIDIA
- given-names: Duane
family-names: Merrill
email: dumerrill@nvidia.com
affiliation: NVIDIA
- given-names: Aniket
family-names: Shivam
email: ashivam@nvidia.com
affiliation: NVIDIA
- given-names: Piotr
family-names: Majcher
email: pmajcher@nvidia.com
affiliation: NVIDIA
- given-names: Paul
family-names: Springer
email: pspringer@nvidia.com
affiliation: NVIDIA
- given-names: Markus
family-names: Hohnerbach
affiliation: NVIDIA
email: mhohnerbach@nvidia.com
- given-names: Jin
family-names: Wang
email: jinw@nvidia.com
affiliation: NVIDIA
- given-names: Matt
family-names: Nicely
email: mnicely@nvidia.com
affiliation: NVIDIA
repository-code: 'https://github.com/NVIDIA/cutlass'
abstract: >-
CUTLASS is a collection of CUDA C++ template
abstractions for implementing high-performance
matrix-multiplication (GEMM) and related
computations at all levels and scales within CUDA.
It incorporates strategies for hierarchical
decomposition and data movement similar to those
used to implement cuBLAS and cuDNN. CUTLASS
decomposes these "moving parts" into reusable,
modular software components abstracted by C++
template classes. These thread-wide, warp-wide,
block-wide, and device-wide primitives can be
specialized and tuned via custom tiling sizes, data
types, and other algorithmic policy. The resulting
flexibility simplifies their use as building blocks
within custom kernels and applications.
keywords:
- 'cutlass, tensor cores, cuda'
license: BSD-3-Clause
license-url: https://github.com/NVIDIA/cutlass/blob/v2.9.0/LICENSE.txt
version: '2.9'
date-released: '2022-04-27'
identifiers:
- type: url
value: "https://github.com/NVIDIA/cutlass/tree/v2.9.0"
description: The GitHub release URL of tag 2.9.0

View File

@ -38,7 +38,7 @@ endif()
message(STATUS "CMake Version: ${CMAKE_VERSION}")
project(CUTLASS VERSION 2.9.0 LANGUAGES CXX)
project(CUTLASS VERSION 2.10.0 LANGUAGES CXX)
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
if (CUDA_VERSION VERSION_LESS 10.2)

View File

@ -11,12 +11,19 @@ Andrew Kerr
Haicheng Wu
Manish Gupta
Dustyn Blasig
Pradeep Ramani
Cris Cecka
Vijay Thakkar
Aniket Shivam
Honghao Lu
Ethan Yan
Zhaodong Chen
Jack Kosaian
Yujia Zhai
Naila Farooqui
Piotr Majcher
Paul Springer
Jin Wang
Aniket Shivam
Chinmay Talegaonkar
Shang Zhang
Scott Yokim
@ -53,7 +60,6 @@ Nick Zhao
## ACKNOWLEDGEMENTS
Girish Bharambe
Cris Cecka
Luke Durant
Olivier Giroux
Stephen Jones

View File

@ -1,8 +1,8 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
# CUTLASS 2.9
# CUTLASS 2.10
_CUTLASS 2.9 - April 2022_
_CUTLASS 2.10 - August 2022_
CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) and related computations at all levels
@ -18,7 +18,9 @@ To support a wide variety of applications, CUTLASS provides extensive support fo
mixed-precision computations, providing specialized data-movement and
multiply-accumulate abstractions for half-precision floating
point (FP16), BFloat16 (BF16), Tensor Float 32 (TF32),
single-precision floating point (FP32), double-precision floating
single-precision floating point (FP32),
[FP32 emulation via tensor core instruction](/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm),
double-precision floating
point (FP64) types, integer data types (4b and 8b), and binary data types (1b).
CUTLASS demonstrates warp-synchronous matrix multiply operations
targeting the programmable, high-throughput _Tensor Cores_ implemented by
@ -34,26 +36,14 @@ See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
See the [functionality listing](/media/docs/functionality.md) for the list of operations
supported at each level of the execution model hierarchy.
# What's New in CUTLASS 2.9
# What's New in CUTLASS 2.10
CUTLASS 2.9 is an update to CUTLASS adding:
- [First layer Convolution kernels](/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment
- [BLAS3](https://docs.nvidia.com/cuda/cublas/index.html#cublas-level-3-function-reference) operators accelerated by Tensor Cores
- [SYRK](/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu), [HERK](/test/unit/gemm/device/herk_cf32h_cf32n_tensor_op_fast_f32_sm80.cu),
- [SYR2K](/test/unit/gemm/device/syr2k_f32n_f32n_tensor_op_fast_f32_sm80.cu), [HER2K](/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu),
- [Out-of-place TRMM](/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu), and
- [SYMM](/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu), [HEMM](/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_fast_f32_ls_sm80.cu)
- [CUTLASS Python](/examples/40_cutlass_py) demonstrating JIT compilation of CUTLASS kernels and a Python-based runtime using [CUDA Python](https://developer.nvidia.com/cuda-python)
- [GEMM + Softmax example](/examples/35_gemm_softmax)
- [Gather and Scatter Fusion with GEMM](/examples/36_gather_scatter_fusion) can gather inputs and scatters outputs based on indices vectors in the same GEMM kernel.
- [Back-to-back GEMM/CONV](examples/13_two_tensor_op_fusion) fully supports buffering the first GEMM/CONV results in the shared memory for the latter one to use. Bias Vector add is also supported in the first GEMM/CONV.
- [Transposed Convolution](/examples/34_transposed_conv2d) (a.k.a Deconvolution) support which reuses Dgrad implementation.
- [Utility functions](/tools/util/include/cutlass/util) that can pad NHWC and convert between NCHW and NHWC.
- [Small alignment implicit gemm](https://github.com/NVIDIA/cutlass/issues/242) support for Fprop/Dgrad/Wgrad so that padding is no longer mandated to use tensor cores.
- Epilogue enhancement with performance improvement, more activation functions, and more fusion patterns.
- [Group GEMM](/examples/24_gemm_grouped) thread block number calculation fix.
- Optimal performance using [CUDA 11.7](https://developer.nvidia.com/cuda-downloads)
- [Parallel GEMM splitk](https://github.com/NVIDIA/cutlass/pull/277) support in the CUTLASS profiler.
CUTLASS 2.10 is an update to CUTLASS adding:
- [Grouped convolution targeting implicit GEMM](test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu)
- [Depthwise separable convolution](test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu)
- Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel
- [Grouped GEMM for Multihead Attention](examples/50_multi_head_attention)
- [GEMM + Layer norm fusion for Ampere](examples/37_gemm_layernorm_gemm_fusion/)
- Updates and bugfixes from the community (thanks!)
- **Deprecation announcement:** CUTLASS plans to deprecate the following:
- Maxwell and Pascal GPU architectures
@ -249,15 +239,15 @@ examples/
12_gemm_bias_relu/ # example demonstrating GEMM fused with bias and relu
13_fused_two_gemms/ # example demonstrating two GEMms fused in one kernel
13_fused_two_gemms/ # example demonstrating two GEMMs fused in one kernel
22_ampere_tensorop_conv2dfprop/ # example demonstrating integer implicit GEMM convolution (forward propagation) using Ampere Tensor Cores
31_basic_syrk # example demonstrating Symetric rank-K update
31_basic_syrk # example demonstrating Symmetric Rank-K update
32_basic_trmm #
32_basic_trmm # example demonstrating Triangular Matrix-Matrix multiplication
33_ampere_3xtf32_tensorop_symm #
33_ampere_3xtf32_tensorop_symm # example demonstrating Symmetric Matrix-Matrix multiplication with FP32 emulation
35_gemm_softmax # example demonstrating GEMM fused with Softmax in mixed precision using Ampere Tensor Cores

View File

@ -54,12 +54,11 @@ using ElementInputA = cutlass::half_t; // <- data type of elements
using ElementInputB = cutlass::half_t; // <- data type of elements in input matrix B
using ElementOutput = float; // <- data type of elements in output matrix D
// The code section below describes matrix layout of input and output matrices.
// Column Major for Matrix A, B and C.
// Note that if the output is column major, the bias has to be per row, i.e. every row has a different bias.
// If the output is row major, the bias has to be per column, i.e. every column has a different bias.
// Some additional notes are listed below:
//
// Note this example only works for ColumnMajor output because
// 1) we only have a row-major epilogue, and
// 2) if the output is column major, we swap A and B so that we can still use the
// row-major epilogue.
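The effect described above can be checked with a small host-side sketch in plain C++ (illustrative only, not the CUTLASS epilogue): computing the transposed problem D^T = B^T * A^T with the length-m bias treated as a per-column bias, and storing the result row-major, reproduces the column-major D = A * B with a per-row bias element for element.

#include <algorithm>
#include <cassert>
#include <vector>

// E = ReLU(X * Y + bias), bias broadcast per column; X is p x q, Y is q x r,
// and E is p x r, all addressed row-major.
void gemm_bias_relu_rowmajor(int p, int q, int r,
                             std::vector<float> const &X,
                             std::vector<float> const &Y,
                             std::vector<float> const &bias,   // length r
                             std::vector<float> &E) {
  for (int i = 0; i < p; ++i) {
    for (int j = 0; j < r; ++j) {
      float acc = bias[j];
      for (int k = 0; k < q; ++k) {
        acc += X[i * q + k] * Y[k * r + j];
      }
      E[i * r + j] = std::max(acc, 0.0f);
    }
  }
}

int main() {
  int m = 3, n = 4, k = 2;
  std::vector<float> A    = {1, 2,  3, 4,  5, 6};        // m x k, row-major
  std::vector<float> B    = {1, 0, 1, 0,  0, 1, 0, 1};   // k x n, row-major
  std::vector<float> bias = {0.5f, -10.0f, 1.5f};        // per row of D (length m)

  // Direct computation: column-major D[i][j] = ReLU(sum_k A[i][k] * B[k][j] + bias[i]).
  std::vector<float> D_colmajor(m * n);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = bias[i];
      for (int kk = 0; kk < k; ++kk) {
        acc += A[i * k + kk] * B[kk * n + j];
      }
      D_colmajor[j * m + i] = std::max(acc, 0.0f);       // column-major store
    }
  }

  // Swapped computation: build B^T (n x k) and A^T (k x m), then run the
  // row-major GEMM + bias + ReLU. The length-m bias is now a per-column bias.
  std::vector<float> Bt(n * k), At(k * m);
  for (int i = 0; i < k; ++i)
    for (int j = 0; j < n; ++j)
      Bt[j * k + i] = B[i * n + j];
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < k; ++j)
      At[j * m + i] = A[i * k + j];

  std::vector<float> Dt_rowmajor(n * m);
  gemm_bias_relu_rowmajor(n, k, m, Bt, At, bias, Dt_rowmajor);

  // Row-major D^T and column-major D occupy memory identically.
  assert(Dt_rowmajor == D_colmajor);
  return 0;
}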

View File

@ -457,9 +457,13 @@ Result profile_convolution(Options const &options) {
ElementInputB(-8),
0);
// Fill tensor C on host with zeros
cutlass::reference::host::TensorFill(
tensor_c.host_view());
// Fill tensor C on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view(),
1,
ElementOutput(7),
ElementOutput(-8),
0);
// Fill tensor D on host with zeros
cutlass::reference::host::TensorFill(
@ -686,7 +690,7 @@ int main(int argc, char const **args) {
cudaDeviceProp props;
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
if (!(props.major >= 8)) {
std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
<< std::endl;
notSupported = true;

View File

@ -290,7 +290,7 @@ int main(int argc, char const **args) {
cudaDeviceProp props;
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
if (!(props.major >= 8)) {
std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
<< std::endl;
notSupported = true;

View File

@ -326,7 +326,7 @@ int main(int argc, char const **args) {
cudaDeviceProp props;
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
if (!(props.major >= 8)) {
std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
<< std::endl;
notSupported = true;

View File

@ -32,7 +32,7 @@
/**
The example demonstrates how to reduce one of the operands of the GEMM along the k-dimension while
computing the GEMM, so the output also contains either an Mx1 or a 1xN vector. It only works with Ampere
HMMA 16x8x16 FP16 tensor cores, though it is not difficult to apply to other Turing/Ampere tensor
16x8x16 FP16/BF16 tensor cores, though it is not difficult to apply to other Turing/Ampere tensor
core instructions.
Most of the reduction is done at the gemm/warp level; see gemm/warp/mma_with_reduction_tensor_op.h
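For reference, the quantity the fused kernel produces can be written as a plain host-side loop. The sketch below is illustrative only and assumes the reduced operand is A and the reduction is a sum over k; in the CUTLASS kernel the reduction happens at the warp level alongside the MMA instructions rather than in a separate pass.

#include <vector>

// D = alpha * A * B + beta * C, plus a length-M vector holding the sum of A over k.
void gemm_with_a_reduction(int M, int N, int K, float alpha, float beta,
                           std::vector<float> const &A,   // M x K, row-major
                           std::vector<float> const &B,   // K x N, row-major
                           std::vector<float> const &C,   // M x N, row-major
                           std::vector<float> &D,         // M x N, row-major
                           std::vector<float> &reduce_A)  // length M: sum over k of A[m][k]
{
  for (int m = 0; m < M; ++m) {
    float row_sum = 0.0f;
    for (int k = 0; k < K; ++k) {
      row_sum += A[m * K + k];            // the extra k-dimension reduction
    }
    reduce_A[m] = row_sum;
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[m * K + k] * B[k * N + n];
      }
      D[m * N + n] = alpha * acc + beta * C[m * N + n];
    }
  }
}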
@ -67,9 +67,9 @@ epilogue/threadblock/epilogue_gemm_k_reduction.h
// elements
using ElementAccumulator = float; // Data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // Data type of epilogue computation
using ElementInputA = cutlass::half_t; // Data type of elements in input tensor
using ElementInputB = cutlass::half_t; // Data type of elements in input tensor
using ElementOutput = cutlass::half_t; // Data type of elements in output tensor
using ElementInputA = cutlass::bfloat16_t; // Data type of elements in input tensor
using ElementInputB = cutlass::bfloat16_t; // Data type of elements in input tensor
using ElementOutput = cutlass::bfloat16_t; // Data type of elements in output tensor
using LayoutInputA = cutlass::layout::ColumnMajor;
using LayoutInputB = cutlass::layout::RowMajor;
@ -369,22 +369,22 @@ Result profile(Options const &options) {
cutlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
ElementInputA(4),
ElementInputA(-4),
ElementInputA(2),
ElementInputA(-2),
0); // <- Fill tensor A on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
1,
ElementInputB(4),
ElementInputB(-4),
ElementInputB(2),
ElementInputB(-2),
0); // <- Fill tensor B on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view(),
1,
ElementOutput(4),
ElementOutput(-4),
ElementOutput(2),
ElementOutput(-2),
0); // <- Fill matrix C on host with uniform-distribution random data
cutlass::reference::host::TensorFill(
tensor_d.host_view()); // <- fill matrix D on host with zeros
@ -612,10 +612,10 @@ Result profile(Options const &options) {
if (options.reference_check) {
output_workspace << "Reference D = \n" << tensor_ref_d.host_view() << "\n\n";
output_workspace << "Reference reduction vector= \n" << tensor_ref_reduction.host_view() << "\n\n";
output_workspace << "Reference reduction vector = \n" << tensor_ref_reduction.host_view() << "\n\n";
}
output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl;
output_workspace << "Computed D = \n" << tensor_d.host_view() << std::endl;
output_workspace << "Computed reduction vector = \n" << tensor_reduction.host_view() << std::endl;
std::cout << "Results written to '" << ss.str() << "'." << std::endl;
@ -699,7 +699,7 @@ int main(int argc, char const **args) {
cudaDeviceProp props;
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
if (!(props.major >= 8)) {
std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
<< std::endl;
notSupported = true;

File diff suppressed because it is too large.

View File

@ -34,3 +34,8 @@ cutlass_example_add_executable(
ampere_fprop_mainloop_fusion.cu
)
cutlass_example_add_executable(
25_ampere_3d_fprop_mainloop_fusion
ampere_3d_fprop_mainloop_fusion.cu
)

View File

@ -0,0 +1,776 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
This example shows how to fuse per-channel scale+bias+relu of the activations
into the 3D fprop mainloop.
Compared with the original 3D fprop kernel, this example has two additional vectors, one for
the scale and one for the bias. The length of each vector equals the number of
activation channels. The kernel loads the vectors when the associated
activation channels are loaded in the mainloop. Between reading the
activations and scale/bias data from shared memory and issuing tensor core
instructions, scale+bias+relu is computed in the register file.
This example is customized for the Ampere 16x8x16 FP16 tensor core instruction.
Changing to different data types or a different tensor core instruction requires
source code changes. See
include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h for more
technical details.
This example is based on 25_ampere_fprop_mainloop_fusion and accepts the same
command line options.
*/
#include <iostream>
#include <fstream>
#include <sstream>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/conv/kernel/default_conv3d_fprop_fusion.h"
#include "cutlass/conv/device/implicit_gemm_convolution_fusion.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/convolution.h"
#include "cutlass/util/tensor_view_io.h"
#include "helper.h"
// The code section below describes datatype for input, output tensors and computation between
// elements
using ElementAccumulator = float; // Data type of accumulator
using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta)
using ElementInputA = cutlass::half_t; // Data type of elements in input tensor
using ElementInputB = cutlass::half_t; // Data type of elements in input tensor
using ElementInputScaleBias = cutlass::half_t; // Data type of elements in input scale and bias vectors
using ElementOutput = float; // Data type of elements in output tensor
using LayoutInputA = cutlass::layout::TensorNDHWC;
using LayoutInputB = cutlass::layout::TensorNDHWC;
using LayoutInputScaleBias = cutlass::layout::RowMajor;
using LayoutOutput = cutlass::layout::TensorNDHWC;
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = cutlass::arch::OpClassTensorOp;
// This code section describes CUDA SM architecture number
using SmArch = cutlass::arch::Sm80;
// This code section describes the tile size a thread block will compute
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape
// This code section describes tile size a warp will compute
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape
// This code section describes the size of MMA op
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
// Number of pipelines you want to use
constexpr int NumStages = 4;
// This code section describes whether the iterator algorithm used is Analytic or Optimized
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized;
// This code section describes the epilogue part of the kernel, we use default value
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput, // Data type of output matrix.
128 / cutlass::sizeof_bits<ElementOutput>::value, // The number of elements per vectorized.
// memory access. This becomes the vector width of
// math instructions in the epilogue too.
ElementAccumulator, // Data type of accumulator
ElementComputeEpilogue>; // Data type for alpha/beta in linear combination
using Conv3dFpropFusionKernel = typename cutlass::conv::kernel::DefaultConv3dFpropFusion<
ElementInputA, LayoutInputA,
ElementInputB, LayoutInputB,
ElementInputScaleBias, LayoutInputScaleBias,
ElementOutput, LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
cutlass::arch::OpMultiplyAdd,
IteratorAlgorithm
>::Kernel;
using ImplicitGemmFusion = cutlass::conv::device::ImplicitGemmConvolutionFusion<Conv3dFpropFusionKernel>;
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
cutlass::Tensor5DCoord input_size;
cutlass::Tensor5DCoord filter_size;
cutlass::Coord<3> padding;
cutlass::Coord<3> conv_stride;
cutlass::Coord<3> dilation;
bool reference_check;
bool measure_performance;
int iterations;
bool save_workspace;
ElementComputeEpilogue alpha;
ElementComputeEpilogue beta;
bool benchmark;
std::string tag;
Options():
help(false),
input_size(1, 32, 32, 32, 32),
filter_size(32, 3, 3, 3, 32),
padding(cutlass::make_Coord(1, 1, 1)),
conv_stride(cutlass::make_Coord(1, 1, 1)),
dilation(cutlass::make_Coord(1, 1, 1)),
reference_check(true),
measure_performance(false),
iterations(20),
save_workspace(false),
alpha(1),
beta(0),
benchmark(false) { }
// Verify the problem size is compatible with the CUTLASS Convolution implementation.
bool valid() {
//
// CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently,
// all pointers, strides, and tensor extents must be divisible by 8 elements.
//
int const kAlignment = 8;
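// (A 128-bit access holds 128 / 16 = 8 cutlass::half_t elements, hence the factor of 8.)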
if ((input_size.c() % kAlignment) ||
(filter_size.n() % kAlignment)) {
// misaligned tensors
return false;
}
// Invalid padding
if ((padding[0] != filter_size.d() / 2) ||
(padding[1] != filter_size.h() / 2) ||
(padding[2] != filter_size.w() / 2)) {
return false;
}
return true;
}
/// Updates input and filter sizes
void update(
cutlass::Tensor5DCoord input_size,
cutlass::Tensor5DCoord filter_size,
cutlass::Coord<3> stride) {
this->input_size = input_size;
this->filter_size = filter_size;
conv_stride = stride;
padding[0] = filter_size.d() / 2;
padding[1] = filter_size.h() / 2;
padding[2] = filter_size.w() / 2;
}
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
if (cmd.check_cmd_line_flag("ref-check")) {
reference_check = true;
}
if (cmd.check_cmd_line_flag("perf-check")) {
measure_performance = true;
}
if (cmd.check_cmd_line_flag("save-workspace")) {
save_workspace = true;
}
if (cmd.check_cmd_line_flag("benchmark")) {
benchmark = true;
}
cmd.get_cmd_line_argument("n", input_size.n());
cmd.get_cmd_line_argument("d", input_size.d());
cmd.get_cmd_line_argument("h", input_size.h());
cmd.get_cmd_line_argument("w", input_size.w());
cmd.get_cmd_line_argument("c", input_size.c());
cmd.get_cmd_line_argument("k", filter_size.n());
cmd.get_cmd_line_argument("t", filter_size.d());
cmd.get_cmd_line_argument("r", filter_size.h());
cmd.get_cmd_line_argument("s", filter_size.w());
filter_size.c() = input_size.c();
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("tag", tag);
if (filter_size.d() == 3 && filter_size.h() == 3 && filter_size.w() == 3) {
padding = cutlass::make_Coord(1, 1, 1);
}
else {
filter_size.d() = 1;
filter_size.h() = 1;
filter_size.w() = 1;
padding = cutlass::make_Coord(0, 0, 0);
}
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "25_ampere_3d_fprop_mainloop_fusion example\n\n"
<< " This example fuses scale+bias+relu of the activations into Ampere's\n"
<< " Tensor Core operators on F16 data types to compute\n"
<< " forward convolution on tensors of layout NDHWC.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --n <int> Input tensor extent N\n"
<< " --d <int> Input tensor extent D\n"
<< " --h <int> Input tensor extent H\n"
<< " --w <int> Input tensor extent W\n"
<< " --c <int> Input tensor extent C\n"
<< " --k <int> Filter extent K\n"
<< " --t <int> Filter extent T\n"
<< " --r <int> Filter extent R\n"
<< " --s <int> Filter extent S\n\n"
<< " --alpha <float> Epilogue scalar alpha\n"
<< " --beta <float> Epilogue scalar beta\n\n"
<< " --ref-check If set (true), reference check on the host is computed\n"
<< " --perf-check If set (true), performance is measured.\n"
<< " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n"
<< " --iterations <int> Number of profiling iterations to perform.\n"
<< " --save-workspace If set, workspace is written to a text file.\n"
<< " --tag <string> String to replicate across the first column in the results table\n";
out << "\n\nExamples:\n\n"
<< "$ ./25_ampere_3d_fprop_mainloop_fusion --n=32 --d=96 --h=96 --w=96 --c=64 --k=64 --t=1 --r=1 --s=1\n\n"
<< "$ ./25_ampere_3d_fprop_mainloop_fusion --n=1 --d=224 --h=224 --w=224 --c=32 --k=32 --t=3 --r=3 --s=3 --ref-check\n\n"
<< "$ ./25_ampere_3d_fprop_mainloop_fusion --n=19 --d=94 --h=96 --w=96 --c=128 --k=128 --t=1 --r=1 --s=1\n\n";
return out;
}
/// Computes the output tensor size (NZPQK)
cutlass::Tensor5DCoord output_size() const {
return cutlass::Tensor5DCoord(
input_size.n(),
(input_size.d() + padding[0] + padding[0] - filter_size.d()) / conv_stride[0] + 1,
(input_size.h() + padding[1] + padding[1] - filter_size.h()) / conv_stride[1] + 1,
(input_size.w() + padding[2] + padding[2] - filter_size.w()) / conv_stride[2] + 1,
filter_size.n());
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of multiply-adds = (N * Z * P * Q * K) * (T * R * S * C)
int64_t fmas = output_size().product() * int64_t(filter_size.d() * filter_size.h() * filter_size.w() * filter_size.c());
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
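// For the default problem size above (input 1x32x32x32x32, filter 32x3x3x3x32,
// padding 1, stride 1) the output is 1x32x32x32x32, so
// fmas = (1*32*32*32*32) * (3*3*3*32) = 1,048,576 * 864 = 905,969,664,
// i.e. roughly 1.81 GFLOP per convolution.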
};
/////////////////////////////////////////////////////////////////////////////////////////////////
struct Result {
double runtime_ms;
double gflops;
cutlass::Status status;
cutlass::Status reference_check;
cudaError_t error;
Result():
runtime_ms(0),
gflops(0),
status(cutlass::Status::kSuccess),
reference_check(cutlass::Status::kInvalid),
error(cudaSuccess) { }
static std::ostream & print_header(std::ostream &out, Options const &options) {
if (!options.tag.empty()) {
out << "Name,";
}
out << "Layer,N,D,H,W,C,K,T,R,S,Stride_D,Stride_H,Stride_W,Runtime,GFLOPs";
return out;
}
std::ostream & print(std::ostream &out, int idx, Options const &options) {
if (!options.tag.empty()) {
out << options.tag << ",";
}
out
<< "conv_" << idx << ","
<< options.input_size.n() << ","
<< options.input_size.d() << ","
<< options.input_size.h() << ","
<< options.input_size.w() << ","
<< options.input_size.c() << ","
<< options.filter_size.n() << ","
<< options.filter_size.d() << ","
<< options.filter_size.h() << ","
<< options.filter_size.w() << ","
<< options.conv_stride[0] << ","
<< options.conv_stride[1] << ","
<< options.conv_stride[2] << ","
<< runtime_ms << ","
<< gflops;
return out;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Runs one benchmark
Result profile_convolution(Options const &options) {
Result result;
//
// Allocate host-device tensors using the CUTLASS Utilities.
//
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.input_size);
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_transformed_a(options.input_size);
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.filter_size);
cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias>
tensor_a_scale({1, options.input_size.c()});
cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias>
tensor_a_bias({1, options.input_size.c()});
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.output_size());
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(options.output_size());
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(options.output_size());
//
// Initialize tensors
//
// Fill tensor A on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
ElementInputA(3),
ElementInputA(-4),
0);
// Fill scale vector for tensor A on host with uniform-distribution random
// data
cutlass::reference::host::TensorFillRandomUniform(
tensor_a_scale.host_view(),
1,
ElementInputA(3),
ElementInputA(-4),
0);
// Fill bias vector for tensor A on host with uniform-distribution random
// data
cutlass::reference::host::TensorFillRandomUniform(
tensor_a_bias.host_view(),
1,
ElementInputA(3),
ElementInputA(-4),
0);
// Fill tensor B on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
1,
ElementInputB(7),
ElementInputB(-8),
0);
// Fill tensor C on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view(),
1,
ElementOutput(7),
ElementOutput(-8),
0);
// Fill tensor D for reference on host with zeros
cutlass::reference::host::TensorFill(
tensor_ref_d.host_view());
// Copy data from host to GPU
tensor_a.sync_device();
tensor_a_scale.sync_device();
tensor_a_bias.sync_device();
tensor_b.sync_device();
tensor_c.sync_device();
tensor_d.sync_device();
tensor_ref_d.sync_device();
//
// Define arguments for CUTLASS Convolution
//
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
// Split K dimension into 1 partition
int split_k_slices = 1;
// Construct Conv3dProblemSize with user defined output size
cutlass::conv::Conv3dProblemSize problem_size(
options.input_size,
options.filter_size,
options.padding,
options.conv_stride,
options.dilation,
options.output_size(),
mode,
split_k_slices
);
typename ImplicitGemmFusion::Arguments arguments{
problem_size,
tensor_a.device_ref(),
tensor_b.device_ref(),
tensor_a_scale.device_ref(),
tensor_a_bias.device_ref(),
tensor_c.device_ref(),
tensor_d.device_ref(),
{options.alpha, options.beta},
};
//
// Initialize CUTLASS Convolution
//
ImplicitGemmFusion implicit_gemm_fusion_op;
size_t workspace_size = implicit_gemm_fusion_op.get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
result.status = implicit_gemm_fusion_op.can_implement(arguments);
CUTLASS_CHECK(result.status);
result.status = implicit_gemm_fusion_op.initialize(arguments, workspace.get());
CUTLASS_CHECK(result.status);
//
// Launch initialized CUTLASS kernel
//
result.status = implicit_gemm_fusion_op();
CUTLASS_CHECK(result.status);
//
// Optional reference check
//
if (options.reference_check) {
std::cout << "Verification on device...\n";
// Compute scale + bias + relu in host code
for (int n = 0; n < options.input_size.n(); ++n) {
for (int d = 0; d < options.input_size.d(); ++d) {
for (int h = 0; h < options.input_size.h(); ++h) {
for (int w = 0; w < options.input_size.w(); ++w) {
for (int c = 0; c < options.input_size.c(); ++c) {
tensor_transformed_a.at({n, d, h, w, c}) = std::max(
ElementOutput(0), ElementOutput(tensor_a.at({n, d, h, w, c}) *
tensor_a_scale.at({0, c}) +
tensor_a_bias.at({0, c})));
}
}
}
}
}
tensor_transformed_a.sync_device();
// Compute with reference implementation
cutlass::reference::device::Conv3dFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementComputeEpilogue,
ElementAccumulator,
cutlass::NumericConverter<ElementOutput, ElementComputeEpilogue>
>(
problem_size,
tensor_transformed_a.device_ref(),
tensor_b.device_ref(),
tensor_c.device_ref(),
tensor_ref_d.device_ref(),
options.alpha,
options.beta
);
// Check if output from CUTLASS kernel and reference kernel are equal or not
tensor_d.sync_host();
tensor_ref_d.sync_host();
bool passed = cutlass::reference::host::TensorEquals(
tensor_d.host_view(),
tensor_ref_d.host_view());
if (!passed) {
result.reference_check = cutlass::Status::kErrorInternal;
std::cout << "ERROR - results miscompared.\n";
}
else {
result.reference_check = cutlass::Status::kSuccess;
std::cout << "Passed.\n";
}
}
else {
result.reference_check = cutlass::Status::kInvalid;
}
if (options.save_workspace) {
std::stringstream ss;
ss << "25_ampere_3d_fprop_mainloop_fusion"
<< options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
<< "_"
<< options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
<< ".dat";
std::ofstream output_workspace(ss.str());
output_workspace
<< "Input = \n" << tensor_a.host_view() << "\n\n"
<< "Filters = \n" << tensor_b.host_view() << "\n\n";
if (options.reference_check) {
output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n";
}
output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl;
std::cout << "Results written to '" << ss.str() << "'." << std::endl;
}
//
// Performance measurement
//
if (options.measure_performance) {
cudaEvent_t events[2];
for (auto & event : events) {
result.error = cudaEventCreate(&event);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
}
// Record an event at the start of a series of convolution operations.
result.error = cudaEventRecord(events[0]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Launch a sequence of implicit GEMM operations on the device
for (int iteration = 0; iteration < options.iterations; ++iteration) {
result.status = implicit_gemm_fusion_op();
CUTLASS_CHECK(result.status);
}
// Record an event when the convolutions have been launched.
result.error = cudaEventRecord(events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Wait for work on the device to complete.
result.error = cudaEventSynchronize(events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != cudaSuccess) {
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
return result;
}
// Print average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
// Cleanup
for (auto event : events) {
(void)cudaEventDestroy(event);
}
}
return result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
bool notSupported = false;
// Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
//
// CUTLASS must be compiled with CUDA 11 Toolkit to run Conv3dFprop examples.
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) {
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
notSupported = true;
}
cudaDeviceProp props;
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
if (!(props.major >= 8)) {
std::cerr << "This test must run on SM80 or above.\n";
notSupported = true;
}
if (notSupported) {
return 0;
}
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.benchmark) {
// Benchmark several layers
int batch_sizes[] = {34, 18};
struct Benchmark {
int d, h, w, c, k, t, r, s, stride_d, stride_h, stride_w;
} layers[] = {
{56, 56, 56, 64, 256, 1, 1, 1, 1, 1, 1},
{56, 56, 56, 64, 64, 1, 1, 1, 1, 1, 1},
{56, 56, 56, 64, 64, 3, 3, 3, 1, 1, 1},
{56, 56, 56, 256, 64, 1, 1, 1, 1, 1, 1},
{56, 56, 56, 256, 512, 1, 1, 1, 2, 2, 2},
{56, 56, 56, 256, 128, 1, 1, 1, 1, 1, 1},
{56, 56, 56, 128, 128, 3, 3, 3, 2, 2, 2},
{28, 28, 28, 128, 512, 1, 1, 1, 1, 1, 1},
{28, 28, 28, 512, 128, 1, 1, 1, 1, 1, 1},
{28, 28, 28, 128, 128, 3, 3, 3, 1, 1, 1},
{28, 28, 28, 512, 1024, 1, 1, 1, 2, 2, 2},
{28, 28, 28, 512, 256, 1, 1, 1, 1, 1, 1},
{28, 28, 28, 256, 256, 3, 3, 3, 2, 2, 2},
{14, 14, 14, 256, 1024, 1, 1, 1, 1, 1, 1},
{14, 14, 14, 1024, 256, 1, 1, 1, 1, 1, 1},
{14, 14, 14, 256, 256, 3, 3, 3, 1, 1, 1},
{14, 14, 14, 1024, 2048, 1, 1, 1, 2, 2, 2},
{14, 14, 14, 1024, 512, 1, 1, 1, 1, 1, 1},
{14, 14, 14, 512, 512, 3, 3, 3, 2, 2, 2},
{ 7, 7, 7, 512, 2048, 1, 1, 1, 1, 1, 1},
{ 7, 7, 7, 2048, 512, 1, 1, 1, 1, 1, 1},
{ 7, 7, 7, 512, 512, 3, 3, 3, 1, 1, 1},
};
Result::print_header(std::cout, options) << std::endl;
int idx = 1;
for (auto const &layer : layers) {
for (auto N : batch_sizes) {
options.update({N, layer.d, layer.h, layer.w, layer.c},
{layer.k, layer.t, layer.r, layer.s, layer.c},
cutlass::make_Coord(layer.stride_d, layer.stride_h, layer.stride_w));
Result result = profile_convolution(options);
result.print(std::cout, idx, options) << std::endl;
}
++idx;
}
}
else {
// Execute one problem size
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
Result result = profile_convolution(options);
Result::print_header(std::cout, options) << std::endl;
result.print(std::cout, 1, options) << std::endl;
}
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
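For reference, the per-element transform the fused mainloop applies to each activation, mirroring the host verification loop in profile_convolution above, can be written as a standalone pass. The sketch below is illustrative only and assumes a packed NDHWC layout with C as the fastest-varying dimension.

#include <algorithm>
#include <cstddef>
#include <vector>

// out[i] = ReLU(act[i] * scale[c] + bias[c]) with c = i % C for a packed NDHWC tensor.
// The fused kernel performs this in registers between the shared-memory loads and the
// tensor core MMAs; here it is written as a separate host pass purely for illustration.
void scale_bias_relu_ndhwc(std::vector<float> const &act, int C,
                           std::vector<float> const &scale,
                           std::vector<float> const &bias,
                           std::vector<float> &out) {
  for (std::size_t i = 0; i < act.size(); ++i) {
    int c = static_cast<int>(i % C);
    out[i] = std::max(0.0f, act[i] * scale[c] + bias[c]);
  }
}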

View File

@ -429,9 +429,13 @@ Result profile_convolution(Options const &options) {
ElementInputB(-8),
0);
// Fill tensor C on host with zeros
cutlass::reference::host::TensorFill(
tensor_c.host_view());
// Fill tensor C on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view(),
1,
ElementOutput(7),
ElementOutput(-8),
0);
// Fill tensor D on host with zeros
cutlass::reference::host::TensorFill(
@ -575,7 +579,7 @@ Result profile_convolution(Options const &options) {
std::stringstream ss;
ss << "25_ampere_fprop_mainloop_fusion_"
ss << "25_ampere_fprop_mainloop_fusion"
<< options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
<< "_"
<< options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
@ -677,8 +681,8 @@ int main(int argc, char const **args) {
cudaDeviceProp props;
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
if (!(props.major == 8 && props.minor == 0)) {
std::cerr << "This test must run on SM80 A100.\n";
if (!(props.major >= 8)) {
std::cerr << "This test must run on SM80 or above.\n";
notSupported = true;
}

View File

@ -266,8 +266,8 @@ struct Options {
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "26_ampere_fused_wgrad_batch_normalization example\n\n"
<< " This example fuses scale+bias+relu from batch norm into Ampere's\n"
out << "26_ampere_wgrad_mainloop_fusion example\n\n"
<< " This example fuses scale+bias+relu of the activation into Ampere's\n"
<< " Tensor Core operators on F16 data types to compute\n"
<< " backward convolution on tensors of layout NHWC.\n\n"
<< "Options:\n\n"
@ -289,8 +289,8 @@ struct Options {
<< " --tag=<string> String to replicate across the first column in the results table\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/26_ampere_fused_fprop_batch_normalization/26_ampere_fused_wgrad_batch_normalization --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n"
<< "$ ./examples/26_ampere_fused_fprop_batch_normalization/26_ampere_fused_wgrad_batch_normalization --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n";
<< "$ ./examples/26_ampere_wgrad_mainloop_fusion/26_ampere_wgrad_mainloop_fusion --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n"
<< "$ ./examples/26_ampere_wgrad_mainloop_fusion/26_ampere_wgrad_mainloop_fusion --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n";
return out;
}
@ -427,9 +427,13 @@ Result profile_convolution(Options const &options) {
ElementInputA(-4),
0);
// Fill tensor C on host with zeros
cutlass::reference::host::TensorFill(
tensor_c.host_view());
// Fill tensor C on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view(),
1,
ElementOutput(7),
ElementOutput(-8),
0);
// Fill tensor D on host with zeros
cutlass::reference::host::TensorFill(

View File

@ -740,7 +740,7 @@ int main(int argc, char const **args) {
cudaDeviceProp props;
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
if (!(props.major >= 8)) {
std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
<< std::endl;
notSupported = true;

View File

@ -703,7 +703,7 @@ int main(int argc, char const **args) {
cudaDeviceProp props;
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
if (!(props.major >= 8)) {
std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
<< std::endl;
notSupported = true;

View File

@ -603,7 +603,7 @@ int main(int argc, char const **args) {
cudaDeviceProp props;
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
if (!(props.major >= 8)) {
std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
<< std::endl;
notSupported = true;

View File

@ -47,14 +47,17 @@
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm_complex.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/host/tensor_reduce.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/error_metrics.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/epilogue/thread/linear_combination.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -85,18 +88,18 @@ struct Options {
float alpha;
float beta;
bool verification_enabled;
double tolerance;
float tolerance;
Options():
help(false),
problem_size({16, 24, 64}),
batch_count(1), // As a temporary limitation to the test bench, batch count must be 1. The kernels support arbitrary batching.
batch_count(16),
iterations(20),
seed(2022),
alpha(1),
beta(),
beta(0),
verification_enabled(true),
tolerance(0.01)
tolerance(1e-5f)
{ }
bool valid() {
@ -116,6 +119,8 @@ struct Options {
cmd.get_cmd_line_argument("n", problem_size.n());
cmd.get_cmd_line_argument("k", problem_size.k());
cmd.get_cmd_line_argument("batch_count", batch_count);
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
@ -135,6 +140,7 @@ struct Options {
<< " --m=<int> GEMM M dimension\n"
<< " --n=<int> GEMM N dimension\n"
<< " --k=<int> GEMM K dimension\n"
<< " --batch_count=<int> Batch number\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --seed=<int> Random number seed (1*)\n\n"
@ -198,13 +204,22 @@ struct Testbed {
using ElementA = cutlass::half_t;
using ElementB = cutlass::half_t;
using ElementC = cutlass::half_t;
using ElementD = cutlass::half_t;
using ElementCompute = float;
using ElementSoftmax = cutlass::half_t;
using ElementD = ElementC;
using ElementSoftmax = ElementC;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using ArchTag = cutlass::arch::Sm80;
static int const kStages = 3;
/// Linear scaling operator
using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination<
ElementC,
@ -218,12 +233,21 @@ struct Testbed {
ElementB, LayoutB,
ElementC,
ElementCompute,
EpilogueFunctorOp
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueFunctorOp,
kStages
>;
using ElementNorm = typename GemmSoftmax::ElementNorm;
using ElementSum = typename GemmSoftmax::ElementSum;
using LayoutC = typename GemmSoftmax::LayoutC;
using LayoutN = typename GemmSoftmax::LayoutN;
using LayoutS = typename GemmSoftmax::LayoutS;
using MatrixCoord = typename LayoutC::TensorCoord;
//
// Data members
@ -231,20 +255,42 @@ struct Testbed {
Options const &options;
cutlass::HostTensor<ElementA, LayoutA> tensor_A;
cutlass::HostTensor<ElementB, LayoutB> tensor_B;
cutlass::HostTensor<ElementC, LayoutC> tensor_C;
cutlass::HostTensor<ElementD, LayoutC> tensor_D;
cutlass::HostTensor<ElementNorm, LayoutC> tensor_N;
cutlass::HostTensor<ElementSum, LayoutC> tensor_S;
cutlass::HostTensor<ElementSoftmax, LayoutC> tensor_Softmax;
cutlass::HostTensor<ElementD, LayoutC> reference_D;
cutlass::HostTensor<ElementNorm, LayoutC> reference_N;
cutlass::HostTensor<ElementSoftmax, LayoutC> reference_Softmax;
cutlass::DeviceAllocation<ElementA> block_A;
cutlass::DeviceAllocation<ElementB> block_B;
cutlass::DeviceAllocation<ElementC> block_C;
cutlass::DeviceAllocation<ElementD> block_D;
cutlass::DeviceAllocation<ElementD> block_Ref;
cutlass::DeviceAllocation<ElementSoftmax> block_Softmax;
cutlass::DeviceAllocation<ElementNorm> block_Norm;
cutlass::DeviceAllocation<ElementSum> block_Sum;
int block_num = (options.problem_size.n() + GemmSoftmax::ThreadblockShape::kN - 1) / GemmSoftmax::ThreadblockShape::kN;
cutlass::gemm::GemmCoord problem = options.problem_size;
int64_t lda = LayoutA::packed({problem.m(), problem.k()}).stride(0);
int64_t ldb = LayoutB::packed({problem.k(), problem.n()}).stride(0);
int64_t ldc = LayoutC::packed({problem.m(), problem.n()}).stride(0);
// The partial norm and sum tensors use a fixed row-major layout
int64_t ldn = problem.m();
int64_t lds = ldn;
int64_t total_elements_A_per_batch = problem.m() * problem.k();
int64_t total_elements_B_per_batch = problem.k() * problem.n();
int64_t total_elements_C_per_batch = problem.m() * problem.n();
int64_t total_elements_D_per_batch = problem.m() * problem.n();
int64_t total_elements_partial_norm_per_batch = block_num * problem.m();
int64_t total_elements_A = total_elements_A_per_batch * options.batch_count;
int64_t total_elements_B = total_elements_B_per_batch * options.batch_count;
int64_t total_elements_C = total_elements_C_per_batch * options.batch_count;
int64_t total_elements_D = total_elements_D_per_batch * options.batch_count;
int64_t total_elements_partial_norm = total_elements_partial_norm_per_batch * options.batch_count;
//
// Methods
//
@ -254,20 +300,7 @@ struct Testbed {
):
options(options_)
{
tensor_A.reset({options.problem_size.m(), options.problem_size.k()});
tensor_B.reset({options.problem_size.k(), options.problem_size.n()});
tensor_C.reset({options.problem_size.m(), options.problem_size.n()});
tensor_D.reset({options.problem_size.m(), options.problem_size.n()});
tensor_N.reset({block_num, options.problem_size.m()});
tensor_S.reset({block_num, options.problem_size.m()});
tensor_Softmax.reset({options.problem_size.m(), options.problem_size.n()});
reference_D.reset({options.problem_size.m(), options.problem_size.n()}, false);
reference_N.reset({options.problem_size.m(), 1}, false);
reference_Softmax.reset({options.problem_size.m(), options.problem_size.n()}, false);
}
/// Run
@ -300,11 +333,6 @@ struct Testbed {
return disposition;
}
//
// Compute the reference
//
compute_reference();
//
// Verify
//
@ -334,43 +362,38 @@ struct Testbed {
/// Random initialization
void initialize() {
cutlass::reference::host::TensorFillRandomUniform(
tensor_A.host_view(),
options.seed,
ElementD(5),
ElementD(-5),
0
);
block_A.reset(total_elements_A);
block_B.reset(total_elements_B);
block_C.reset(total_elements_C);
block_D.reset(total_elements_D);
block_Softmax.reset(total_elements_D);
block_Ref.reset(total_elements_D_per_batch);
block_Norm.reset(total_elements_partial_norm);
block_Sum.reset(total_elements_partial_norm);
cutlass::reference::host::TensorFillRandomUniform(
tensor_B.host_view(),
options.seed + 19,
ElementD(5),
ElementD(-5),
0
);
cutlass::reference::device::BlockFillRandomUniform(
block_A.get(), total_elements_A, options.seed, ElementA(5), ElementA(-5), 0);
cutlass::reference::host::TensorFill(
reference_D.host_view(),
ElementD()
);
cutlass::reference::device::BlockFillRandomUniform(
block_B.get(), total_elements_B, options.seed + 1, ElementB(5), ElementB(-5), 0);
cutlass::reference::device::BlockFillRandomUniform(
block_C.get(), total_elements_C, options.seed + 2, ElementC(5), ElementC(-5), 0);
cutlass::reference::device::BlockFillRandomUniform(
block_D.get(), total_elements_D, options.seed + 3, ElementD(5), ElementD(-5), 0);
cutlass::reference::device::BlockFillRandomUniform(
block_Ref.get(), total_elements_D_per_batch, options.seed + 3, ElementD(5), ElementD(-5), 0);
cutlass::reference::device::BlockFillRandomUniform(
block_Softmax.get(), total_elements_D, options.seed + 3, ElementSoftmax(5), ElementSoftmax(-5), 0);
cutlass::reference::host::TensorFill(
reference_N.host_view(),
ElementNorm()
);
cutlass::reference::host::TensorFill(
reference_Softmax.host_view(),
ElementSoftmax()
);
tensor_A.sync_device();
tensor_B.sync_device();
tensor_D.sync_device();
tensor_N.sync_device();
tensor_S.sync_device();
tensor_Softmax.sync_device();
}
cutlass::Status execute_device_kernel() {
@ -384,17 +407,24 @@ struct Testbed {
GemmSoftmax::Arguments args(
options.problem_size,
options.batch_count,
tensor_A.device_ref(),
tensor_B.device_ref(),
tensor_C.device_ref(),
tensor_D.device_ref(),
{block_A.get(), lda},
{block_B.get(), ldb},
{block_C.get(), ldc},
{block_D.get(), ldc},
{
ElementCompute(options.alpha),
ElementCompute(options.beta)
},
tensor_N.device_ref(),
tensor_S.device_ref(),
tensor_Softmax.device_ref()
{block_Norm.get(), ldn},
{block_Sum.get(), lds},
{block_Softmax.get(), ldc},
total_elements_A_per_batch,
total_elements_B_per_batch,
total_elements_C_per_batch,
total_elements_D_per_batch,
total_elements_partial_norm_per_batch,
total_elements_partial_norm_per_batch,
total_elements_D_per_batch
);
//
@ -415,68 +445,21 @@ struct Testbed {
return status;
}
/// Reference calculation
void compute_reference() {
template<typename Element>
bool verify_tensor(std::vector<Element> vector_Input, \
std::vector<Element> vector_Input_Ref) {
// Compute GEMM
cutlass::reference::host::GemmComplex(
options.problem_size,
options.alpha,
tensor_A.host_ref(),
cutlass::ComplexTransform::kNone,
tensor_B.host_ref(),
cutlass::ComplexTransform::kNone,
options.beta,
tensor_C.host_ref(),
reference_D.host_ref(),
double()
);
// Compute the norm
for (int m = 0; m < options.problem_size.m(); ++m) {
reference_N.at({m, 0}) = reference_D.at({m, 0});
for (int n = 1; n < options.problem_size.n(); ++n) {
reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementNorm(reference_D.at({m, n})));
}
}
// Compute softmax
for (int m = 0; m < options.problem_size.m(); ++m) {
float sum = float();
for (int n = 0; n < options.problem_size.n(); ++n) {
sum += std::exp( float(reference_D.at({m, n})) - float(reference_N.at({m, 0})) );
}
float inv_sum = float(1.0f / sum);
for (int n = 0; n < options.problem_size.n(); ++n) {
reference_Softmax.at({m, n}) = ElementSoftmax(
std::exp( float(reference_D.at({m, n})) - float(reference_N.at({m, 0})) ) * inv_sum
);
}
}
}
/// Emits all tensor values
void emit_results() {
std::cout << "D = \n" << tensor_D.host_view() << "\n\n";
std::cout << "N = \n" << tensor_N.host_view() << "\n\n";
std::cout << "Softmax = \n" << tensor_Softmax.host_view() << "\n\n";
std::cout << "Reference N = \n" << reference_N.host_view() << "\n\n";
std::cout << "Reference D = \n" << reference_D.host_view() << "\n\n";
std::cout << "Reference Softmax = \n" << reference_Softmax.host_view() << "\n\n";
}
bool verify_tensor_N(cutlass::HostTensor<ElementNorm, LayoutC> tensor_N, \
cutlass::HostTensor<ElementNorm, LayoutC> reference_N) {
for (int m = 0; m < options.problem_size.m(); ++m) {
float diff = (float)(tensor_N.at({0, m}) - reference_N.at({m, 0}));
if (fabs(diff) > options.tolerance) {
int64_t size = (vector_Input.size() < vector_Input_Ref.size()) ? vector_Input.size() : vector_Input_Ref.size();
float abs_tol = options.tolerance;
float rel_tol = options.tolerance;
for (int64_t i = 0; i < size; ++i) {
float diff = (float)(vector_Input.at(i) - vector_Input_Ref.at(i));
float abs_diff = fabs(diff);
float abs_ref = fabs((float)vector_Input_Ref.at(i));
float relative_diff = abs_ref > abs_tol ? abs_diff / abs_ref : 0;
if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > rel_tol && relative_diff > rel_tol)) {
printf("diff = %f, {%f, %f}.\n", abs_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i)));
return false;
}
@ -488,80 +471,112 @@ struct Testbed {
/// Verifies the reference matches
bool verify() {
tensor_D.sync_host();
tensor_N.sync_host();
tensor_Softmax.sync_host();
LayoutA layout_A(lda);
LayoutB layout_B(ldb);
LayoutC layout_C(ldc);
LayoutN Layout_N(ldn);
LayoutS Layout_S(lds);
double const kThreshold = options.tolerance;
MatrixCoord extent_A{problem.m(), problem.k()};
MatrixCoord extent_B{problem.k(), problem.n()};
MatrixCoord extent_C{problem.m(), problem.n()};
// Verification checks - set any of these to 'true' to override the verification checks.
bool verified_D = false;
bool verified_N = false;
bool verified_Softmax = false;
for (int batch_idx = 0; batch_idx < options.batch_count; batch_idx++) {
// Verify softmax output
if (!verified_D) {
cutlass::TensorView<ElementA, LayoutA> view_A(block_A.get() + total_elements_A_per_batch * batch_idx, layout_A, extent_A);
cutlass::TensorView<ElementB, LayoutB> view_B(block_B.get() + total_elements_B_per_batch * batch_idx, layout_B, extent_B);
cutlass::TensorView<ElementC, LayoutC> view_C(block_C.get() + total_elements_C_per_batch * batch_idx, layout_C, extent_C);
cutlass::TensorView<ElementC, LayoutC> view_Ref_device(block_Ref.get(), layout_C, extent_C);
double norm_diff = cutlass::reference::host::TensorNormDiff(
tensor_D.host_view(),
reference_D.host_view());
cutlass::reference::device::GemmComplex<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ElementCompute, ElementCompute
>(
problem,
options.alpha,
view_A,
cutlass::ComplexTransform::kNone,
view_B,
cutlass::ComplexTransform::kNone,
options.beta,
view_C,
view_Ref_device,
ElementCompute(0)
);
double norm_reference = cutlass::reference::host::TensorNorm(
reference_D.host_view());
// Copy reference results to host memory for verification
std::vector<ElementD> matrix_D_Ref(layout_C.capacity(extent_C));
cutlass::device_memory::copy_to_host(matrix_D_Ref.data(), block_Ref.get(), matrix_D_Ref.size());
cutlass::TensorView<ElementD, LayoutC> view_Ref(matrix_D_Ref.data(), layout_C, extent_C);
double rel_error = norm_diff / norm_reference;
std::vector<ElementSoftmax> matrix_Softmax_Ref(layout_C.capacity(extent_C));
cutlass::TensorView<ElementSoftmax, LayoutC> view_Softmax_Ref(matrix_Softmax_Ref.data(), layout_C, extent_C);
if (rel_error > kThreshold) {
std::cerr << "\n\nTensor D Relative error: " << rel_error << std::endl;
// Copy computed results to host memory
std::vector<ElementD> matrix_D(layout_C.capacity(extent_C));
cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + total_elements_D_per_batch * batch_idx, matrix_D.size());
std::vector<ElementD> matrix_Softmax(layout_C.capacity(extent_C));
cutlass::device_memory::copy_to_host(matrix_Softmax.data(), block_Softmax.get() + total_elements_D_per_batch * batch_idx, matrix_Softmax.size());
// Compute the norm
for (int m = 0; m < options.problem_size.m(); ++m) {
reference_N.at({m, 0}) = view_Ref.ref().at({m, 0});
for (int n = 1; n < options.problem_size.n(); ++n) {
reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementNorm(view_Ref.ref().at({m, n})));
}
}
else {
verified_D = true;
// Compute softmax
for (int m = 0; m < options.problem_size.m(); ++m) {
float sum = float();
for (int n = 0; n < options.problem_size.n(); ++n) {
sum += std::exp( float(view_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) );
}
float inv_sum = float(1.0f / sum);
for (int n = 0; n < options.problem_size.n(); ++n) {
view_Softmax_Ref.ref().at({m, n}) = ElementSoftmax(
std::exp( float(view_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ) * inv_sum
);
}
}
}
if (!verified_N) {
verified_N = verify_tensor_N(tensor_N, reference_N);
}
// Verification checks - set any of these to 'true' to override the verification checks.
bool verified_D = false;
bool verified_Softmax = false;
if (!verified_Softmax) {
double norm_diff = cutlass::reference::host::TensorNormDiff(
tensor_Softmax.host_view(),
reference_Softmax.host_view());
double norm_reference = cutlass::reference::host::TensorNorm(
reference_Softmax.host_view());
double rel_error = norm_diff / norm_reference;
if (rel_error > kThreshold) {
std::cerr << "\n\nSoftmax Relative error: " << rel_error << std::endl;
}
else {
verified_Softmax = true;
}
}
if (!verified_D || !verified_N || !verified_Softmax) {
std::cerr << "Verification check failed for tensor Softmax" << std::endl;
emit_results();
// Summarize which checks failed
// Verify softmax output
if (!verified_D) {
std::cerr << "Verification of D tensor failed\n";
}
if (!verified_N) {
std::cerr << "Verification of N tensor failed\n";
verified_D = verify_tensor<ElementC>(matrix_D, matrix_D_Ref);
}
if (!verified_Softmax) {
std::cerr << "Verification of Softmax tensor failed\n";
verified_Softmax = verify_tensor<ElementSoftmax>(matrix_Softmax, matrix_Softmax_Ref);
}
if (!verified_D || !verified_Softmax) {
std::cerr << "Verification check failed for tensor Softmax at batch " << batch_idx << "\n";
// Summarize which checks failed
if (!verified_D) {
std::cerr << "Verification of D tensor failed\n";
}
if (!verified_Softmax) {
std::cerr << "Verification of Softmax tensor failed\n";
}
return false;
}
return false;
}
return true;
@ -637,14 +652,17 @@ struct Testbed {
int64_t flops = int64_t(options.problem_size.m()) * options.problem_size.n() * options.problem_size.k() * 2;
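// Byte count assumes two accesses of D per element (GEMM store + softmax-pass load) plus one store of the Softmax output.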
int64_t bytes = (sizeof(ElementD) * 2 + sizeof(ElementSoftmax)) * options.problem_size.m() * options.problem_size.n();
double gflops_per_second = double(flops) * kIterations / double(elapsed_ms / 1000.0f) / double(1.0e9);
double gbytes_per_second = double(bytes) * kIterations / double(elapsed_ms / 1000.0f) / double(1 << 30);
double gflops_per_second = double(flops) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1.0e9);
double gbytes_per_second = double(bytes) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1 << 30);
double elapsed_ms_per_iter = double(elapsed_ms) / kIterations;
std::cout << " Problem: "
<< options.problem_size.m() << "-by-" << options.problem_size.n() << "-by-" << options.problem_size.k()
<< ", batch size: " << options.batch_count
<< std::endl;
std::cout << " Runtime: " << elapsed_ms << " ms\n" << std::endl;
std::cout << " Runtime: " << elapsed_ms_per_iter << " ms\n" << std::endl;
std::cout << " GFLOPs: " << gflops_per_second << " GFLOPs" << std::endl;
std::cout << "Memory bandwidth: " << gbytes_per_second << " GiB/s" << std::endl;

View File

@ -29,7 +29,8 @@
*
**************************************************************************************************/
/*! \file
\brief GEMM kernel to support the 'epilogue visitor' model for fusion.
\brief GEMM kernel to support the epilogue visitor model
for customized softmax partial reduction epilogue fusion.
This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once
its usage has been stabilized. For now, it is included in this example to demonstrate
@ -78,6 +79,7 @@ public:
using ElementC = typename EpilogueVisitor::ElementOutput;
using LayoutC = typename Epilogue::Layout;
using TensorRefC = TensorRef<ElementC, LayoutC>;
static ComplexTransform const kTransformA = Mma::kTransformA;
static ComplexTransform const kTransformB = Mma::kTransformB;
@ -89,6 +91,9 @@ public:
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
using ElementNorm = typename EpilogueVisitor::ElementNorm;
using ElementSum = typename EpilogueVisitor::ElementSum;
static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
@ -121,6 +126,11 @@ public:
TensorRefA ref_A;
TensorRefB ref_B;
TensorRefC ref_C;
TensorRefC ref_D;
ElementNorm *ptr_Max;
ElementSum *ptr_Sum;
int64_t batch_stride_A;
int64_t batch_stride_B;
@ -144,6 +154,10 @@ public:
int batch_count_,
TensorRefA ref_A_,
TensorRefB ref_B_,
TensorRefC ref_C_,
TensorRefC ref_D_,
ElementNorm *ptr_Max_,
ElementSum *ptr_Sum_,
int64_t batch_stride_A_,
int64_t batch_stride_B_,
typename EpilogueVisitor::Arguments epilogue_visitor_
@ -153,6 +167,10 @@ public:
batch_count(batch_count_),
ref_A(ref_A_),
ref_B(ref_B_),
ref_C(ref_C_),
ref_D(ref_D_),
ptr_Max(ptr_Max_),
ptr_Sum(ptr_Sum_),
batch_stride_A(batch_stride_A_),
batch_stride_B(batch_stride_B_),
epilogue_visitor(epilogue_visitor_)
@ -174,6 +192,8 @@ public:
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
typename EpilogueVisitor::OutputTileIterator::Params params_C;
typename EpilogueVisitor::OutputTileIterator::Params params_D;
GemmUniversalMode mode;
int batch_count;
@ -181,6 +201,11 @@ public:
void * ptr_A;
void * ptr_B;
ElementC * ptr_C;
ElementC * ptr_D;
ElementNorm * ptr_Max;
ElementSum * ptr_Sum;
int64_t batch_stride_A;
int64_t batch_stride_B;
@ -196,11 +221,17 @@ public:
swizzle_log_tile(0),
params_A(0),
params_B(0),
params_C(0),
params_D(0),
batch_count(0),
gemm_k_size(0),
mode(cutlass::gemm::GemmUniversalMode::kGemm),
ptr_A(nullptr),
ptr_B(nullptr),
ptr_C(nullptr),
ptr_D(nullptr),
ptr_Max(nullptr),
ptr_Sum(nullptr),
batch_stride_A(0),
batch_stride_B(0)
{ }
@ -213,11 +244,17 @@ public:
swizzle_log_tile(0),
params_A(args.ref_A.layout()),
params_B(args.ref_B.layout()),
params_C(args.ref_C.layout()),
params_D(args.ref_D.layout()),
mode(args.mode),
batch_count(args.batch_count),
gemm_k_size(args.problem_size.k()),
ptr_A(args.ref_A.data()),
ptr_B(args.ref_B.data()),
ptr_C(args.ref_C.data()),
ptr_D(args.ref_D.data()),
ptr_Max(args.ptr_Max),
ptr_Sum(args.ptr_Sum),
batch_stride_A(args.batch_stride_A),
batch_stride_B(args.batch_stride_B),
epilogue_visitor(args.epilogue_visitor)
@ -467,7 +504,14 @@ public:
thread_idx,
warp_idx,
lane_idx,
threadblock_offset);
params.params_C,
params.params_D,
params.ptr_C,
params.ptr_D,
params.ptr_Max,
params.ptr_Sum,
threadblock_offset,
blockIdx.y * params.problem_size.m());
if (params.mode == GemmUniversalMode::kGemm) {
// Indicate which position in a serial reduction the output operator is currently updating

View File

@ -49,10 +49,12 @@
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h"
#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h"
#include "cutlass/reduction/kernel/reduce_softmax_final.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "epilogue_with_visitor.h"
#include "gemm_with_epilogue_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -209,6 +211,9 @@ private:
int idx_m = block_m + thread_m;
int idx_n = block_n + thread_n;
int batch_offset_norm = block_batch * params.args.batch_stride_N;
int batch_offset_sum = block_batch * params.args.batch_stride_S;
// Kill off thread if it is outside the row boundary
if (params.args.extent.row() <= idx_m) {
return;
@ -251,8 +256,8 @@ private:
params.args.batch_stride_Soft * block_batch +
params.args.ref_Soft.layout()({idx_m, idx_n}));
ElementSum inv_sum = (params.args.ref_S.data())[block_m];
ElementNorm norm = (params.args.ref_N.data())[block_m];
ElementSum inv_sum = (params.args.ref_S.data())[block_m + batch_offset_sum];
ElementNorm norm = (params.args.ref_N.data())[block_m + batch_offset_norm];
//
// Loop
@ -281,556 +286,6 @@ private:
}
};
template <
typename ElementNorm_,
typename ElementSum_,
typename ElementSoftmaxCompute_,
typename ThreadblockShape_
>
class ApplyFinalReduction {
public:
using ElementNorm = ElementNorm_;
using ElementSum = ElementSum_;
using ElementSoftmaxCompute = ElementSoftmaxCompute_;
using ThreadblockShape = ThreadblockShape_;
using Layout = cutlass::layout::RowMajor;
using TensorRefN = TensorRef<ElementNorm, Layout>;
using TensorRefSum = TensorRef<ElementSum, Layout>;
//
// Arguments
//
struct Arguments {
MatrixCoord extent; ///< Extent of D and Softmax matrices
int batch_count; ///< Batch count
TensorRefN ref_N; ///< Norm tensor (input / output)
TensorRefSum ref_Sum; ///< Sum tensor (input / output)
int64_t batch_stride_N; ///< Batch stride for N tensor
int64_t batch_stride_Sum; ///< Batch stride for softmax tensor
//
// Methods
//
Arguments():
batch_count(1),
batch_stride_N(0),
batch_stride_Sum(0)
{ }
Arguments(
MatrixCoord extent_, ///< Extent of D and Softmax matrices
int batch_count_, ///< Batch count
TensorRefN ref_N_, ///< Output parameter for N
TensorRefSum ref_Sum_ , ///< Sum
int64_t batch_stride_N_ = 0,
int64_t batch_stride_Sum_ = 0
):
extent(extent_),
batch_count(batch_count_),
ref_N(ref_N_),
ref_Sum(ref_Sum_),
batch_stride_N(batch_stride_N_),
batch_stride_Sum(batch_stride_Sum_)
{
}
};
struct SharedStorage {
};
//
// Params struct
//
struct Params {
Arguments args;
//
// Methods
//
Params() { }
Params(Arguments const &args_): args(args_) { }
};
private:
public:
CUTLASS_DEVICE
ApplyFinalReduction() { }
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
apply(params, shared_storage);
}
private:
/// Partial reduction
CUTLASS_DEVICE
void apply(Params const &params, SharedStorage &shared_storage) {
int threadblock_num = (params.args.extent.column() + ThreadblockShape::kN - 1) / ThreadblockShape::kN;
int block_batch = blockIdx.z;
int block_n = blockIdx.x * blockDim.x;
int thread_n = threadIdx.x;
int idx_n = block_n + thread_n;
if (idx_n >= params.args.extent.row()) {
return;
}
using ConvertSumOutput = cutlass::NumericConverter<ElementSum, ElementSoftmaxCompute>;
using ConvertNormOutput = cutlass::NumericConverter<ElementNorm, ElementSoftmaxCompute>;
using ConvertSum = cutlass::NumericConverter<ElementSoftmaxCompute, ElementSum>;
using ConvertNorm = cutlass::NumericConverter<ElementSoftmaxCompute, ElementNorm>;
ConvertSum convert_sum;
ConvertNorm convert_norm;
ConvertSumOutput convert_sum_output;
ConvertNormOutput convert_norm_output;
ElementNorm *access_n = params.args.ref_N.data() + params.args.batch_stride_N * block_batch + idx_n;
ElementSum *access_s = params.args.ref_Sum.data() + params.args.batch_stride_Sum * block_batch + idx_n;
ElementNorm *access_n_bak = access_n;
ElementSum *access_s_bak = access_s;
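// 0xff7fffff is the IEEE-754 bit pattern of -FLT_MAX, used as the identity value for the max reduction below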
uint32_t float_max_bits = 0xff7fffff;
float min_float = reinterpret_cast<float const &>(float_max_bits);
ElementSoftmaxCompute max_val = ElementSoftmaxCompute(min_float);
ElementSoftmaxCompute sum_val = ElementSoftmaxCompute(0);
ElementNorm fetch_n;
ElementSum fetch_s;
CUTLASS_PRAGMA_UNROLL
for (int idx_m = 0; idx_m < threadblock_num; idx_m++) {
arch::global_load<ElementNorm, sizeof(ElementNorm)>(fetch_n, access_n, true);
max_val = fast_max(max_val, convert_norm(fetch_n));
access_n += params.args.extent.row();
}
access_n = access_n_bak;
CUTLASS_PRAGMA_UNROLL
for (int idx_m = 0; idx_m < threadblock_num; idx_m++) {
arch::global_load<ElementNorm, sizeof(ElementNorm)>(fetch_n, access_n, true);
arch::global_load<ElementSum, sizeof(ElementSum)>(fetch_s, access_s, true);
sum_val += convert_sum(fetch_s) * fast_exp(convert_norm(fetch_n) - max_val);
access_n += params.args.extent.row();
access_s += params.args.extent.row();
}
ElementSoftmaxCompute inv_sum = cutlass::constants::one<ElementSoftmaxCompute>() / sum_val;
access_n = access_n_bak;
access_s = access_s_bak;
access_n[0] = convert_norm_output(max_val);
access_s[0] = convert_sum_output(inv_sum);
}
};
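// The partial-reduction scheme above can be summarized with a minimal host-side sketch
// (plain C++, illustrative only and not part of the CUTLASS sources; names are hypothetical):
// each threadblock contributes a local row maximum and a local sum of exp(x - local_max),
// and the final reduction rescales those partial sums to the global maximum before inverting the total.
#include <algorithm>
#include <cmath>
#include <limits>
#include <utility>
#include <vector>
// partials[i] = {local_max_i, local_sum_i}, where local_sum_i accumulates exp(x - local_max_i)
// over the columns owned by threadblock i of one row.
std::pair<float, float> combine_softmax_partials(std::vector<std::pair<float, float>> const &partials) {
  float row_max = -std::numeric_limits<float>::max();
  for (auto const &p : partials) {
    row_max = std::max(row_max, p.first);
  }
  float sum = 0.f;
  for (auto const &p : partials) {
    // Rescale each partial sum from its local maximum to the global row maximum.
    sum += p.second * std::exp(p.first - row_max);
  }
  // Matching the kernel above: store the row maximum and the inverse of the total sum,
  // so the softmax apply pass can compute exp(x - row_max) * inv_sum per element.
  return {row_max, 1.f / sum};
}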
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename ThreadblockShape_,
int ThreadCount,
typename OutputTileIterator_,
typename ElementAccumulator_,
typename ElementNorm_,
typename ElementSum_,
typename ElementSoftmaxCompute_,
typename ElementwiseFunctor_
>
class EpilogueVisitorBiasMax {
public:
using ThreadblockShape = ThreadblockShape_;
static int const kThreadCount = ThreadCount;
using OutputTileIterator = OutputTileIterator_;
using ElementwiseFunctor = ElementwiseFunctor_;
static int const kIterations = OutputTileIterator::kIterations;
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
using ElementOutput = typename OutputTileIterator::Element;
using LayoutOutput = cutlass::layout::RowMajor;
using ElementAccumulator = ElementAccumulator_;
using ElementNorm = ElementNorm_;
using ElementSum = ElementSum_;
using ElementSoftmaxCompute = ElementSoftmaxCompute_;
using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
using SoftmaxFragment = Array<ElementSoftmaxCompute, kElementsPerAccess>;
using OutputVector = Array<ElementOutput, kElementsPerAccess>;
using TensorRefD = TensorRef<ElementOutput, LayoutOutput>;
/// Argument structure
struct Arguments {
typename ElementwiseFunctor::Params elementwise;
TensorRefD ref_C;
TensorRefD ref_D;
ElementNorm *ptr_Max;
ElementSum *ptr_Sum;
int64_t batch_stride_C;
int64_t batch_stride_D;
int64_t batch_stride_Max;
int64_t batch_stride_Sum;
//
// Methods
//
Arguments():
ptr_Max(nullptr),
ptr_Sum(nullptr),
batch_stride_C(0),
batch_stride_D(0),
batch_stride_Max(0),
batch_stride_Sum(0)
{
}
Arguments(
typename ElementwiseFunctor::Params elementwise_,
TensorRefD ref_C_,
TensorRefD ref_D_,
ElementNorm *ptr_Max_,
ElementSum *ptr_Sum_,
int64_t batch_stride_C_,
int64_t batch_stride_D_,
int64_t batch_stride_Max_,
int64_t batch_stride_Sum_
):
elementwise(elementwise_),
ref_C(ref_C_),
ref_D(ref_D_),
ptr_Max(ptr_Max_),
ptr_Sum(ptr_Sum_),
batch_stride_C(batch_stride_C_),
batch_stride_D(batch_stride_D_),
batch_stride_Max(batch_stride_Max_),
batch_stride_Sum(batch_stride_Sum_)
{
}
};
struct Params {
typename ElementwiseFunctor::Params elementwise;
typename OutputTileIterator::Params params_C;
typename OutputTileIterator::Params params_D;
typename OutputTileIterator::Element *ptr_C;
typename OutputTileIterator::Element *ptr_D;
ElementNorm *ptr_Max;
ElementSum *ptr_Sum;
int64_t batch_stride_C;
int64_t batch_stride_D;
int64_t batch_stride_Max;
int64_t batch_stride_Sum;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
ptr_D(nullptr),
ptr_Max(nullptr),
ptr_Sum(nullptr)
{
}
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
elementwise(args.elementwise),
params_C(args.ref_C.layout()),
params_D(args.ref_D.layout()),
ptr_C(args.ref_C.data()),
ptr_D(args.ref_D.data()),
ptr_Max(args.ptr_Max),
ptr_Sum(args.ptr_Sum),
batch_stride_C(args.batch_stride_C),
batch_stride_D(args.batch_stride_D),
batch_stride_Max(args.batch_stride_Max),
batch_stride_Sum(args.batch_stride_Sum)
{
}
};
/// Shared storage
struct SharedStorage {
};
private:
Params const & params_;
SharedStorage & shared_storage_;
MatrixCoord extent_;
ElementwiseFunctor elementwise_;
OutputTileIterator iterator_C_;
OutputTileIterator iterator_D_;
typename OutputTileIterator::Fragment fragment_C_;
typename OutputTileIterator::Fragment fragment_D_;
ElementAccumulator alpha_;
ElementAccumulator beta_;
ElementSoftmaxCompute accum_max_;
int threadblock_row_;
public:
CUTLASS_DEVICE
EpilogueVisitorBiasMax(
Params const &params, ///< Parameters routed to the epilogue
SharedStorage &shared_storage, ///< Shared storage needed by the functors here
MatrixCoord const &problem_size, ///< Problem size of the output
int thread_idx, ///< Thread index within the threadblock
int warp_idx, ///< Warp index within the threadblock
int lane_idx, ///< Lane index within the warp
MatrixCoord const &threadblock_offset = MatrixCoord(0, 0)
):
params_(params),
shared_storage_(shared_storage),
extent_(problem_size),
elementwise_(params.elementwise),
iterator_C_(params.params_C, params.ptr_C, problem_size, thread_idx, threadblock_offset),
iterator_D_(params.params_D, params.ptr_D, problem_size, thread_idx, threadblock_offset),
threadblock_row_(threadblock_offset.row())
{
alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha);
beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);
if (beta_ == ElementAccumulator()) {
iterator_C_.clear_mask();
}
}
/// Helper to indicate split-K behavior
CUTLASS_DEVICE
void set_k_partition(
int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
int split_k_slices) { ///< Total number of split-K slices
}
/// Called to set the batch index
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
}
/// Called at the start of the epilogue just before iterating over accumulator slices
CUTLASS_DEVICE
void begin_epilogue() {
}
/// Called at the start of one step before starting accumulator exchange
CUTLASS_DEVICE
void begin_step(int step_idx) {
fragment_D_.clear();
fragment_C_.clear();
if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
iterator_C_.load(fragment_C_);
++iterator_C_;
}
}
/// Called at the start of a row
CUTLASS_DEVICE
void begin_row(int row_idx) {
}
/// Called after accumulators have been exchanged for each accumulator vector
CUTLASS_DEVICE
void visit(
int row_idx,
int column_idx,
int frag_idx,
AccumulatorFragment const &accum) {
using Mul = cutlass::multiplies<SoftmaxFragment>;
using Minus = cutlass::minus<SoftmaxFragment>;
using Exp = cutlass::fast_exp_op<SoftmaxFragment>;
Minus minus;
Exp exponential;
SoftmaxFragment result;
using ConvertSumOutput = cutlass::NumericConverter<ElementSoftmaxCompute, ElementSum>;
using ConvertNormOutput = cutlass::NumericConverter<ElementSoftmaxCompute, ElementNorm>;
ConvertSumOutput convert_sum_output;
ConvertNormOutput convert_norm_output;
NumericArrayConverter<ElementSoftmaxCompute, ElementOutput, kElementsPerAccess> source_converter;
OutputVector &source_vector = reinterpret_cast<OutputVector *>(&fragment_C_)[frag_idx];
if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
result = source_converter(elementwise_(accum));
}else{
result = source_converter(elementwise_(accum, source_vector));
}
MatrixCoord thread_offset =
iterator_D_.thread_start() +
OutputTileIterator::ThreadMap::iteration_offset(frag_idx);
int thread_in_row = OutputTileIterator::ThreadMap::Detail::RowArrangement::Detail::kShapeWidth;
int half_thread_in_row = (thread_in_row >> 1);
bool column_guard = (thread_offset.column() < extent_.column());
// Compute the maximum within one row
if (!column_idx) {
// This is the first fragment in a new row
if (column_guard) {
accum_max_ = maximum_accumulator_(result);
}
}
else {
// This is an additional fragment in the same row
if (column_guard) {
accum_max_ = maximum_accumulator_(result, accum_max_);
}
}
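// Butterfly (XOR-shuffle) reduction: after log2(kShapeWidth) steps, every thread sharing the row holds the row-wide maximum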
CUTLASS_PRAGMA_UNROLL
for (int i = half_thread_in_row; i > 0; i >>= 1) {
ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, accum_max_, i);
accum_max_ = fast_max(accum_max_, tmp);
}
SoftmaxFragment sum_frag = exponential(minus(result, accum_max_));
ElementSoftmaxCompute reduction_sum = sum_accumulator_(sum_frag);
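// The same butterfly pattern accumulates the partial sum of exp(x - row_max) across the threads sharing the row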
CUTLASS_PRAGMA_UNROLL
for (int i = half_thread_in_row; i > 0; i >>= 1) {
ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, reduction_sum, i);
reduction_sum += tmp;
}
bool is_write_thread = (thread_offset.row() < extent_.row() && (threadIdx.x % thread_in_row) == 0);
ElementNorm *curr_ptr_max = params_.ptr_Max + thread_offset.row() + blockIdx.y * extent_.row();
ElementSum *curr_ptr_sum = params_.ptr_Sum + thread_offset.row() + blockIdx.y * extent_.row();
arch::global_store<ElementNorm, sizeof(ElementNorm)>(
convert_norm_output(accum_max_),
(void *)curr_ptr_max,
is_write_thread);
arch::global_store<ElementSum, sizeof(ElementSum)>(
convert_sum_output(reduction_sum),
(void *)curr_ptr_sum,
is_write_thread);
clear_accum_max_();
// Convert to the output
NumericArrayConverter<ElementOutput, ElementSoftmaxCompute, kElementsPerAccess> output_converter;
OutputVector &output = reinterpret_cast<OutputVector *>(&fragment_D_)[frag_idx];
output = output_converter(result);
}
/// Called at the end of a row
CUTLASS_DEVICE
void end_row(int row_idx) {
}
/// Called after all accumulator elements have been visited
CUTLASS_DEVICE
void end_step(int step_idx) {
iterator_D_.store(fragment_D_);
++iterator_D_;
}
/// Called after all steps have been completed
CUTLASS_DEVICE
void end_epilogue() {
}
private:
CUTLASS_DEVICE
void clear_accum_max_() {
uint32_t float_max_bits = 0xff7fffff; // -FLT_MAX
float min_float = reinterpret_cast<float const &>(float_max_bits);
accum_max_ = ElementSoftmaxCompute(min_float);
}
CUTLASS_DEVICE
ElementSoftmaxCompute sum_accumulator_(SoftmaxFragment const &accum) {
ElementSoftmaxCompute sum_ = ElementSoftmaxCompute(0);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < SoftmaxFragment::kElements; ++i) {
sum_ += ElementSoftmaxCompute(accum[i]);
}
return sum_;
}
CUTLASS_DEVICE
ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum) {
ElementSoftmaxCompute max_ = accum[0];
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < SoftmaxFragment::kElements; ++i) {
max_ = fast_max(max_, ElementSoftmaxCompute(accum[i]));
}
return max_;
}
CUTLASS_DEVICE
ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum, ElementSoftmaxCompute max_) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < SoftmaxFragment::kElements; ++i) {
max_ = fast_max(max_, ElementSoftmaxCompute(accum[i]));
}
return max_;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -846,10 +301,18 @@ template <
typename LayoutB_,
typename ElementC_,
typename ElementCompute_,
typename OperatorClass_,
typename ArchTag_,
typename ThreadblockShape_,
typename WarpShape_,
typename InstructionShape_,
typename EpilogueFunctorOp_,
int kStages_,
int AlignmentA_ = 128 / cutlass::sizeof_bits<ElementA_>::value,
int AlignmentB_ = 128 / cutlass::sizeof_bits<ElementB_>::value,
int AlignmentSoftmax_ = 128 / cutlass::sizeof_bits<ElementC_>::value,
typename ElementNorm_ = float,
typename ElementSum_ = float,
int Alignment = 128 / cutlass::sizeof_bits<ElementA_>::value,
typename ElementSoftmax_ = ElementC_
>
class GemmSoftmax {
@ -872,8 +335,6 @@ public:
using LayoutA = LayoutA_;
using LayoutB = LayoutB_;
static int const kAlignment = Alignment;
using EpilogueFunctorOp = EpilogueFunctorOp_;
using ElementNorm = ElementNorm_;
@ -890,13 +351,17 @@ public:
using TensorRefSum = TensorRef<ElementSum, LayoutS>;
using TensorRefSoft = TensorRef<ElementSoft, LayoutSoft>;
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using ArchTag = cutlass::arch::Sm80;
static int const kStages = 3;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
static int const kStages = kStages_;
static int const AlignmentA = AlignmentA_;
static int const AlignmentB = AlignmentB_;
static int const AlignmentSoftmax = AlignmentSoftmax_;
using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;
@ -906,10 +371,10 @@ public:
using DefaultGemmKernel = typename cutlass::gemm::kernel::DefaultGemm<
ElementA,
LayoutA,
kAlignment,
AlignmentA,
ElementB,
LayoutB,
kAlignment,
AlignmentB,
ElementC,
LayoutC,
ElementCompute,
@ -930,7 +395,7 @@ public:
///////////////////////////////////////////////////////////////////////////////////////////////
// Epilogue visitor
using EpilogueVisitor = kernel::EpilogueVisitorBiasMax<
using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorSoftmax<
ThreadblockShape,
DefaultGemmKernel::kThreadCount,
typename DefaultGemmKernel::Epilogue::OutputTileIterator,
@ -961,13 +426,13 @@ public:
ElementSum,
ElementSoft,
ElementSoftmaxCompute,
kAlignment,
AlignmentSoftmax,
MatrixShape<
1, 1024
>
>;
using ApplyFinalReductionKernel = kernel::ApplyFinalReduction<
using ApplyFinalReductionKernel = cutlass::reduction::kernel::ApplySoftmaxFinalReduction<
ElementNorm,
ElementSum,
ElementSoftmaxCompute,
@ -983,6 +448,7 @@ public:
typename SoftmaxApplyKernel::Arguments softmax;
typename ApplyFinalReductionKernel::Arguments reduction;
cutlass::gemm::GemmCoord extend;
//
// Methods
//
@ -1013,14 +479,14 @@ public:
batch_count_,
ref_A_,
ref_B_,
ref_C_,
ref_D_,
ref_N_.data(),
ref_S_.data(),
batch_stride_A_,
batch_stride_B_,
typename EpilogueVisitor::Arguments(
linear_scaling,
ref_C_,
ref_D_,
ref_N_.data(),
ref_S_.data(),
batch_stride_C_,
batch_stride_D_,
batch_stride_Max_,
@ -1028,10 +494,9 @@ public:
)
),
reduction(
MatrixCoord(problem_size.m(), problem_size.n()),
batch_count_,
ref_N_,
ref_S_,
problem_size,
ref_N_.data(),
ref_S_.data(),
batch_stride_Max_,
batch_stride_Sum_
),
@ -1127,28 +592,24 @@ public:
// Launch the ApplyFinalReductionKernel
//
int threadblock_num_in_column = (params_.extend.column() + ThreadblockShape::kN - 1) / ThreadblockShape::kN;
int thread_per_block = 128;
int block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block;
if (block_per_row < 4) {
thread_per_block = 32;
block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block;
}
if (threadblock_num_in_column > 1) {
int thread_per_block = 128;
int block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block;
if (block_per_row < 4) {
thread_per_block = 32;
block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block;
}
dim3 final_reduction_grid(block_per_row, 1, params_.softmax.args.batch_count);
dim3 final_reduction_block(thread_per_block);
dim3 final_reduction_grid(block_per_row);
dim3 final_reduction_block(thread_per_block);
Kernel<ApplyFinalReductionKernel><<<
final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionKernel::SharedStorage), stream
>>>(params_.reduction);
Kernel<ApplyFinalReductionKernel><<<
final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionKernel::SharedStorage), stream
>>>(params_.reduction);
result = cudaGetLastError();
result = cudaGetLastError();
if (result != cudaSuccess) {
return cutlass::Status::kErrorInternal;
}
if (result != cudaSuccess) {
return cutlass::Status::kErrorInternal;
}
//

View File

@ -40,18 +40,17 @@
// for (int j = 0; j < options.index_size; ++j) {
// int b_c_d_col = tensor_indices.at({j, 0});
//
// for (int k = 0; k < problem_size.k(); ++k) {
// for (int k = 0; k < options.index_size; ++k) {
// tensor_d_ref.at({i, b_c_d_col}) +=
// alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col});
// }
// }
// }
//
// Note that the index vector contains unique random integers whose maximum value is N - 1
//
// The gather/scatter operation works best when we can still keep the largest
// alignment. For example, when the matrix is row major, we select rows. When
// the matrix is column major, we selct columns.
// the matrix is column major, we select columns.
//
// Not all the combination of gather and scatter are legal. For example, if A is
// row major and C/D is column major, we cannot gather A and scatter C/D at the
@ -257,7 +256,7 @@ using Gemm = cutlass::gemm::device::GemmUniversal<ElementInputA,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
false, /*GatherA*/
false, /*GatherA*/
true, /*GatherB*/
true /*ScatterD*/
>;
@ -353,7 +352,7 @@ int run(Options &options) {
tensor_b.layout().stride(),
tensor_c.layout().stride(),
tensor_d_scattered.layout().stride(),
nullptr, // <- pointer to index vector to gather A on device
nullptr, // <- pointer to index vector to gather A on device
tensor_indices.device_data(), // <- pointer to index vector to gather B on device
tensor_indices.device_data()}; // <- pointer to index vector to scatter D on device
@ -392,7 +391,7 @@ int run(Options &options) {
tensor_d_ref.at({i, b_c_d_col}) +=
alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col});
}
tensor_d_ref.at({i, b_c_d_col}) += (beta * tensor_c.at({i, b_c_d_col}));
}
}
@ -515,7 +514,7 @@ int main(int argc, const char ** argv) {
cudaDeviceProp props;
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
if (!(props.major >= 8)) {
std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
<< std::endl;
notSupported = true;

View File

@ -0,0 +1,36 @@
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
37_gemm_layernorm_gemm_fusion
gemm_layernorm.cu
)

View File

@ -0,0 +1,937 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief CUTLASS Layernorm Example.
This workload provides a layer normalization example using a one-pass, square-sum-based
variance calculation. Specifically, the reduction that produces the partial sum and the
partial sum of squares is fused into the epilogue of the 1st GEMM. After a lightweight
full-reduction kernel, the mean and variance are readily available for the element-wise
normalization, which is fused into the 2nd GEMM.
As stated in https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data,
the square-sum-based one-pass implementation may raise numerical stability concerns.
Although this fully fused layernorm example hides almost all of the memory cost of
accessing the intermediate matrix used for the layernorm computation, those numerical
issues may limit its use in real-world scenarios. In that case, consider the stand-alone
CUTLASS layernorm kernel in tools/util/include/cutlass/util/device_layernorm.h
Examples:
# Run a CUTLASS layernorm example with the default setup,
# using the terminology of the transformer model as an example
# (column-major output matrix, hidden dimension = 768, valid word number = 4096, intermediate_scale = 4)
$ ./examples/37_gemm_layernorm_gemm_fusion/37_gemm_layernorm_gemm_fusion
# Run a CUTLASS layernorm example with hidden dimension = 512
$ ./examples/37_gemm_layernorm_gemm_fusion/37_gemm_layernorm_gemm_fusion --hidden_dim=512
*/
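// A minimal plain-C++ sketch of the one-pass statistics described above (illustrative only;
// names are hypothetical and not part of this example, which computes the same quantities
// inside the 1st GEMM's epilogue and a small reduction kernel):
#include <cmath>
#include <vector>
struct LayernormStats {
  float inv_std;       // 1 / sqrt(variance + eps)
  float shifted_mean;  // -mean * inv_std, so normalization reduces to x * inv_std + shifted_mean
};
// One pass over a row: accumulate the sum and the sum of squares, then derive mean and variance.
LayernormStats one_pass_stats(std::vector<float> const &row, float eps = 1e-6f) {
  float sum = 0.f;
  float square_sum = 0.f;
  for (float x : row) {
    sum += x;
    square_sum += x * x;
  }
  float n = static_cast<float>(row.size());
  float mean = sum / n;
  float variance = square_sum / n - mean * mean;  // E[x^2] - (E[x])^2
  LayernormStats s;
  s.inv_std = 1.0f / std::sqrt(variance + eps);
  s.shifted_mean = -mean * s.inv_std;
  return s;
}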
#include <cmath>
#include <iostream>
#include <vector>
#include <limits>
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/device/gemm_complex.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/gemm_complex.h"
#include "cutlass/util/reference/host/tensor_reduce.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/error_metrics.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/fast_math.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "gemm_with_layernorm.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
enum class Disposition {
kPassed,
kIncorrect,
kNotVerified
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
template<typename LayoutOutput_>
struct Options {
using LayoutOutput = LayoutOutput_;
static bool const kIsColumnMajorOutput = cutlass::platform::is_same<LayoutOutput, cutlass::layout::ColumnMajor>::value;
bool help;
cutlass::gemm::GemmCoord problem_size0;
cutlass::gemm::GemmCoord problem_size1;
int hidden_dim;
int valid_word_num;
int intermediate_scale;
int iterations;
unsigned seed;
float alpha;
float beta;
bool verification_enabled;
double tolerance;
Options():
help(false),
iterations(20),
seed(2022),
hidden_dim(768),
valid_word_num(4096),
intermediate_scale(4),
alpha(1),
beta(0),
verification_enabled(true),
tolerance(0.01),
problem_size1(problem_size0.m() * 4, problem_size0.n(), problem_size0.m())
{ }
bool valid() {
return true;
}
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
cmd.get_cmd_line_argument("hidden_dim", hidden_dim, 768);
cmd.get_cmd_line_argument("valid_word_num", valid_word_num, 4096);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("verify", verification_enabled);
cmd.get_cmd_line_argument("seed", seed);
cmd.get_cmd_line_argument("tolerance", tolerance);
if (kIsColumnMajorOutput) {
// column major output setup
problem_size0.m() = hidden_dim;
problem_size0.n() = valid_word_num;
problem_size0.k() = hidden_dim;
problem_size1.m() = hidden_dim * intermediate_scale;
problem_size1.n() = valid_word_num;
problem_size1.k() = hidden_dim;
}else{
// row major output setup
problem_size0.m() = valid_word_num;
problem_size0.n() = hidden_dim;
problem_size0.k() = hidden_dim;
problem_size1.m() = valid_word_num;
problem_size1.n() = hidden_dim * intermediate_scale;
problem_size1.k() = hidden_dim;
}
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "37_gemm_layernorm_gemm_fusion example\n\n"
<< " This example uses the CUTLASS Library to compute GEMM + Layernorm for arbitrary problem sizes.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --hidden_dim=<int> Hidden dimension\n"
<< " --valid_word_num=<int> Valid word number\n"
<< " --seed=<int> Random number seed (1*)\n\n"
<< " --iterations=<int> Number of profiling iterations to perform (0 to disable profiling).\n\n"
<< " --verify=<bool> If true, performs reference calculation.\n\n"
<< " --tolerance <float> Error tolerance\n"
;
out << "\n\nExamples:\n\n"
<< "$ ./examples/37_gemm_layernorm_gemm_fusion/37_gemm_layernorm_gemm_fusion \\\n"
<< " --hidden_dim=768 --valid_word_num=1024 \n\n";
return out;
}
/// Returns true if the environment and Toolkit support this
bool supported(bool verbose = true) const {
// Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available
// in CUDA 11.0.
//
// CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples.
if (!(__CUDACC_VER_MAJOR__ >= 11)) {
if (verbose) {
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
}
return false;
}
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
if (verbose) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
}
return false;
}
if (!((props.major * 10 + props.minor) >= 80)) {
if (verbose) {
std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
<< std::endl;
}
return false;
}
//
// CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently,
// all pointers, strides, and tensor extents must be divisible by 8 elements.
//
int const kAlignment = 8;
if ((problem_size0.m() % kAlignment) ||
(problem_size0.n() % kAlignment) ||
(problem_size0.k() % kAlignment)) {
if (verbose) {
std::cerr << "Misaligned input in 1st GEMM." << std::endl;
}
// misaligned tensors for Gemm1
return false;
}
if ((problem_size1.m() % kAlignment) ||
(problem_size1.n() % kAlignment) ||
(problem_size1.k() % kAlignment)) {
if (verbose) {
std::cerr << "Misaligned input in 2nd GEMM." << std::endl;
}
// misaligned tensors for Gemm2
return false;
}
return true;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename LayoutOutput_>
struct Testbed {
//
// Type definitions
//
// User-defined data types
using ElementInputA0 = cutlass::half_t;
using ElementInputB0 = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using ElementCompute = cutlass::half_t;
using LayoutInputA0 = cutlass::layout::RowMajor;
using LayoutInputB0 = cutlass::layout::ColumnMajor;
using LayoutOutput = LayoutOutput_;
static bool const kIsColumnMajorOutput = cutlass::platform::is_same<LayoutOutput, cutlass::layout::ColumnMajor>::value;
// turn off shifted K by default
static bool const kIsShiftedVariance = false;
/// Linear scaling operator
using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementCompute,
ElementCompute
>;
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
static int const kStages0 = 3;
static int const kStages1 = 4;
using GemmLayernorm = cutlass::GemmLayernorm<
ElementInputA0,
LayoutInputA0,
ElementInputB0,
LayoutInputB0,
ElementOutput,
LayoutOutput,
ElementCompute,
EpilogueFunctorOp,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages0,
kStages1,
kIsShiftedVariance
>;
using ElementInputA1 = typename GemmLayernorm::ElementInputA1;
using ElementOutputC1 = typename GemmLayernorm::ElementOutputC1;
using ElementInputScaleBias = typename GemmLayernorm::ElementInputScaleBias;
using ElementLayernormCompute = typename GemmLayernorm::ElementLayernormCompute;
using LayoutInputA1 = typename GemmLayernorm::LayoutInputA1;
using LayoutOutputC0 = typename GemmLayernorm::LayoutOutputC0;
using LayoutOutputC1 = typename GemmLayernorm::LayoutOutputC1;
using LayoutInputScaleBias = typename GemmLayernorm::LayoutInputScaleBias;
//
// Data members
//
Options<LayoutOutput> const &options;
cutlass::HostTensor<ElementInputA0, LayoutInputA0> tensor_A0;
cutlass::HostTensor<ElementInputB0, LayoutInputB0> tensor_B0;
cutlass::HostTensor<ElementOutput, LayoutOutputC0> tensor_C0;
cutlass::HostTensor<ElementInputA1, LayoutInputA1> tensor_A1;
cutlass::HostTensor<ElementOutputC1, LayoutOutputC1> tensor_C1;
cutlass::HostTensor<ElementOutput, LayoutOutputC0> reference_C0;
cutlass::HostTensor<ElementOutputC1, LayoutOutputC1> reference_C1;
cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias> tensor_Variance;
cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias> tensor_Mean;
cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias> tensor_Beta;
cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias> tensor_Gamma;
cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias> reference_Mean;
cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias> reference_Variance;
// Shifted-K tensor used to improve numerical stability.
// According to https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance,
// the closer the shift K is to the actual mean, the better the numerical stability.
cutlass::HostTensor<ElementOutput, LayoutOutputC0> tensor_Shifted_K;
//
// Methods
//
Testbed(
Options<LayoutOutput> const &options_
):
options(options_)
{
tensor_A0.reset({options.problem_size0.m(), options.problem_size0.k()});
tensor_B0.reset({options.problem_size0.k(), options.problem_size0.n()});
tensor_C0.reset({options.problem_size0.m(), options.problem_size0.n()});
tensor_A1.reset({options.problem_size1.m(), options.problem_size1.k()});
tensor_C1.reset({options.problem_size1.m(), options.problem_size1.n()});
reference_C0.reset({options.problem_size0.m(), options.problem_size0.n()});
reference_C1.reset({options.problem_size1.m(), options.problem_size1.n()});
int leading_dim_0 = kIsColumnMajorOutput ? options.problem_size0.n() : options.problem_size0.m();
int leading_dim_1 = kIsColumnMajorOutput ? options.problem_size0.m() : options.problem_size0.n();
int block_num = (leading_dim_1 + GemmLayernorm::ThreadblockShape::kM - 1) / GemmLayernorm::ThreadblockShape::kM;
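// The partial mean / variance tensors hold one entry per threadblock tile along the reduced
// dimension (block_num of them) for each output position; the final reduction collapses these partials.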
tensor_Variance.reset({block_num, leading_dim_0});
tensor_Mean.reset({block_num, leading_dim_0});
tensor_Shifted_K.reset({1, leading_dim_0});
tensor_Beta.reset({1, leading_dim_1});
tensor_Gamma.reset({1, leading_dim_1});
reference_Mean.reset({1, leading_dim_0}, false);
reference_Variance.reset({1, leading_dim_0}, false);
}
/// Run
Disposition run() {
Disposition disposition = Disposition::kNotVerified;
//
// Initialize the workspace
//
initialize();
//
// Launch device kernel
//
cutlass::Status status = cutlass::Status::kSuccess;
status = execute_device_kernel();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Device execution failed." << std::endl;
return disposition;
}
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Device synchronize failed with error "
<< cudaGetErrorString(result) << std::endl;
return disposition;
}
//
// Compute the reference
//
compute_reference();
//
// Verify
//
if (options.verification_enabled) {
bool passed = verify();
if (passed) {
disposition = Disposition::kPassed;
}
else {
disposition = Disposition::kIncorrect;
}
}
//
// Profiling
//
if (options.iterations) {
profile();
}
return disposition;
}
/// Random initialization
void initialize() {
cutlass::reference::host::TensorFillRandomUniform(
tensor_A0.host_view(),
options.seed,
ElementInputA0(5),
ElementInputA0(-5),
0
);
cutlass::reference::host::TensorFillRandomUniform(
tensor_B0.host_view(),
options.seed + 1,
ElementInputB0(5),
ElementInputB0(-5),
0
);
cutlass::reference::host::TensorFillRandomUniform(
tensor_A1.host_view(),
options.seed + 2,
ElementInputA1(5),
ElementInputA1(-5),
0
);
cutlass::reference::host::TensorFillRandomUniform(
tensor_Beta.host_view(),
options.seed + 3,
ElementInputScaleBias(5),
ElementInputScaleBias(-5),
0
);
cutlass::reference::host::TensorFillRandomUniform(
tensor_Gamma.host_view(),
options.seed + 4,
ElementInputScaleBias(5),
ElementInputScaleBias(-5),
0
);
cutlass::reference::host::TensorFillRandomUniform(
tensor_Shifted_K.host_view(),
options.seed + 5,
ElementOutput(5),
ElementOutput(-6),
0
);
tensor_A0.sync_device();
tensor_B0.sync_device();
tensor_A1.sync_device();
tensor_Beta.sync_device();
tensor_Gamma.sync_device();
}
cutlass::Status execute_device_kernel() {
cutlass::Status status = cutlass::Status::kSuccess;
//
// Setup arguments
//
typename GemmLayernorm::Arguments args(
options.problem_size0,
options.problem_size1,
tensor_A0.device_ref().data(),
tensor_B0.device_ref().data(),
tensor_C0.device_ref().data(),
tensor_C0.device_ref().data(),
tensor_A1.device_ref().data(),
tensor_C1.device_ref().data(),
tensor_A0.device_ref().stride(0),
tensor_B0.device_ref().stride(0),
tensor_C0.device_ref().stride(0),
tensor_C0.device_ref().stride(0),
tensor_A1.device_ref().stride(0),
tensor_C1.device_ref().stride(0),
{
ElementCompute(options.alpha),
ElementCompute(options.beta)
},
tensor_Variance.device_ref(),
tensor_Mean.device_ref(),
tensor_Gamma.device_ref(),
tensor_Beta.device_ref(),
tensor_Shifted_K.device_ref().data()
);
//
// Launch
//
GemmLayernorm gemm_layernorm;
// Initialize
status = gemm_layernorm.initialize(args);
if (status != cutlass::Status::kSuccess) {
return status;
}
// Run
status = gemm_layernorm();
return status;
}
/// Reference calculation
void compute_reference() {
cutlass::reference::device::Gemm<
ElementInputA0,
LayoutInputA0,
ElementInputB0,
LayoutInputB0,
ElementOutput,
LayoutOutputC0,
ElementCompute,
ElementCompute
> gemm_device0;
cutlass::reference::device::Gemm<
ElementInputA1,
LayoutInputA1,
ElementOutput,
LayoutOutputC0,
ElementOutputC1,
LayoutOutputC1,
ElementCompute,
ElementCompute
> gemm_device1;
// Compute 1st GEMM
gemm_device0(
options.problem_size0,
ElementCompute(options.alpha),
tensor_A0.device_ref(),
tensor_B0.device_ref(),
ElementCompute(options.beta),
tensor_C0.device_ref(),
reference_C0.device_ref()
);
reference_C0.sync_host();
tensor_Mean.sync_host();
tensor_Variance.sync_host();
tensor_Gamma.sync_host();
tensor_Beta.sync_host();
tensor_Shifted_K.sync_host();
// Compute the sum and square sum for verification purpose
if (kIsColumnMajorOutput) {
for (int n = 0; n < options.problem_size0.n(); ++n) {
ElementLayernormCompute sum = ElementLayernormCompute(0);
ElementLayernormCompute square_sum = ElementLayernormCompute(0);
for (int m = 0; m < options.problem_size0.m(); ++m) {
sum += ElementLayernormCompute(reference_C0.at({m, n}));
square_sum += ElementLayernormCompute(reference_C0.at({m, n})) * ElementLayernormCompute(reference_C0.at({m, n}));
}
ElementLayernormCompute mean = sum / ElementLayernormCompute(options.problem_size0.m());
ElementLayernormCompute square_mean = square_sum / ElementLayernormCompute(options.problem_size0.m());
ElementLayernormCompute variance = cutlass::constants::one<ElementLayernormCompute>() / cutlass::fast_sqrt(square_mean - mean * mean + ElementLayernormCompute(1e-6) ) ;
mean = -mean * variance;
reference_Mean.at({0, n}) = ElementInputScaleBias(mean);
reference_Variance.at({0, n}) = ElementInputScaleBias(variance);
}
}else{
for (int m = 0; m < options.problem_size0.m(); ++m) {
ElementLayernormCompute sum = ElementLayernormCompute(0);
ElementLayernormCompute square_sum = ElementLayernormCompute(0);
for (int n = 0; n < options.problem_size0.n(); ++n) {
sum += ElementLayernormCompute(reference_C0.at({m, n})) ;
square_sum += ElementLayernormCompute(reference_C0.at({m, n})) * ElementLayernormCompute(reference_C0.at({m, n})) ;
}
ElementLayernormCompute mean = sum / ElementLayernormCompute(options.problem_size0.n());
ElementLayernormCompute square_mean = square_sum / ElementLayernormCompute(options.problem_size0.n());
ElementLayernormCompute variance = cutlass::constants::one<ElementLayernormCompute>() / cutlass::fast_sqrt(square_mean - mean * mean + ElementLayernormCompute(1e-6)) ;
mean = -mean * variance;
reference_Mean.at({0, m}) = ElementInputScaleBias(mean);
reference_Variance.at({0, m}) = ElementInputScaleBias(variance);
}
}
// Element-wise transform for OutputC0 using 1-pass layernorm algo
if (kIsColumnMajorOutput) {
for (int n = 0; n < options.problem_size0.n(); ++n) {
ElementLayernormCompute sum = ElementLayernormCompute(0);
for (int m = 0; m < options.problem_size0.m(); ++m) {
sum += ElementLayernormCompute(reference_C0.at({m, n})) ;
}
ElementInputScaleBias mean = ElementInputScaleBias(sum / ElementLayernormCompute(options.problem_size0.m()));
sum = ElementLayernormCompute(0);
for (int m = 0; m < options.problem_size0.m(); ++m) {
sum += ElementLayernormCompute(reference_C0.at({m, n}) - ElementLayernormCompute(mean)) * ElementLayernormCompute(reference_C0.at({m, n}) - ElementLayernormCompute(mean)) ;
}
ElementLayernormCompute square_mean = sum / ElementLayernormCompute(options.problem_size0.m());
ElementInputScaleBias variance = ElementInputScaleBias(cutlass::constants::one<ElementLayernormCompute>()
/ cutlass::fast_sqrt(square_mean + ElementLayernormCompute(1e-6))) ;
for (int m = 0; m < options.problem_size0.m(); ++m) {
reference_C0.at({m, n}) =
ElementOutput( ( (ElementInputScaleBias(reference_C0.at({m, n})) - mean) * variance )
* tensor_Gamma.at({0, m}) + tensor_Beta.at({0, m}));
}
}
}else{
for (int m = 0; m < options.problem_size0.m(); ++m) {
float sum = float(0);
for (int n = 0; n < options.problem_size0.n(); ++n) {
sum += float(reference_C0.at({m, n})) ;
}
float mean = sum / float(options.problem_size0.n());
sum = float(0);
for (int n = 0; n < options.problem_size0.n(); ++n) {
sum += float(reference_C0.at({m, n}) - mean) * float(reference_C0.at({m, n}) - mean) ;
}
float square_mean = sum / float(options.problem_size0.n());
float variance = cutlass::constants::one<float>() / cutlass::fast_sqrt(square_mean + ElementLayernormCompute(1e-6)) ;
for (int n = 0; n < options.problem_size0.n(); ++n) {
reference_C0.at({m, n}) =
ElementOutput( ( (float(reference_C0.at({m, n})) - mean) * variance )
* float(tensor_Gamma.at({0, n})) + float(tensor_Beta.at({0, n})));
}
}
}
// Sync host data with device after element-wise transform
reference_C0.sync_device();
// Compute 2nd GEMM
gemm_device1(
options.problem_size1,
ElementCompute(options.alpha),
kIsColumnMajorOutput ? tensor_A1.device_ref() : reference_C0.device_ref(),
kIsColumnMajorOutput ? reference_C0.device_ref() : tensor_A1.device_ref(),
ElementCompute(options.beta),
reference_C1.device_ref(),
reference_C1.device_ref()
);
}
/// Emits all tensor values
void emit_results() {
std::cout << "tensor_C1 = \n" << tensor_C1.host_view() << "\n\n";
std::cout << "Reference C1 = \n" << reference_C1.host_view() << "\n\n";
std::cout << "Mean = \n" << tensor_Mean.host_view() << "\n\n";
std::cout << "rsqrt(Variance) = \n" << tensor_Variance.host_view() << "\n\n";
std::cout << "Reference Mean = \n" << reference_Mean.host_view() << "\n\n";
std::cout << "Reference rsqrt(Variance) = \n" << reference_Variance.host_view() << "\n\n";
}
template<typename Element, typename Layout>
bool verify_tensor(cutlass::HostTensor<Element, Layout> tensor, \
cutlass::HostTensor<Element, Layout> reference,
int leading_dim0, int leading_dim1, bool is_print = false) {
float const kThreshold = float(options.tolerance);
float const kAbsThreshold = 0.5f;
float const kRelativeThreshold = 0.1f;
// Add a small constant bias to avoid division by zero
float const kBias = 1e-5f;
int counter = 0;
for (int m = 0; m < leading_dim0; m++) {
for (int n = 0; n < leading_dim1; ++n) {
float diff = (float)(tensor.at({m, n}) - reference.at({m, n}));
float rel_diff = fabs(diff) / fabs(reference.at({m, n}) + kBias);
if (fabs(diff) > kAbsThreshold && rel_diff > kRelativeThreshold) {
counter++;
}
}
}
float err_rate = float(counter) / (float(leading_dim0) * float(leading_dim1));
return (err_rate < kThreshold);
}
/// Verifies the reference matches
bool verify() {
tensor_Variance.sync_host();
tensor_Mean.sync_host();
tensor_C1.sync_host();
reference_C1.sync_host();
// Verification checks - set any of these to 'true' to override the verification checks.
bool verified_C1 = false;
bool verified_Mean = false;
bool verified_Variance = false;
// Verify layernorm output
if (!verified_C1) {
verified_C1 = verify_tensor<ElementOutputC1, LayoutOutputC1>(tensor_C1, reference_C1, options.problem_size1.m(), options.problem_size1.n());
}
if (!verified_Variance) {
verified_Variance = verify_tensor<ElementInputScaleBias, LayoutInputScaleBias>(tensor_Variance, reference_Variance, 1, options.problem_size0.n());
}
if (!verified_Mean) {
verified_Mean = verify_tensor<ElementInputScaleBias, LayoutInputScaleBias>(tensor_Mean, reference_Mean, 1, options.problem_size0.n());
}
if (!verified_C1 || !verified_Mean || !verified_Variance) {
// emit_results();
std::cerr << "Verification check failed for tensor Layernorm" << std::endl;
// Summarize which checks failed
if (!verified_C1) {
std::cerr << "Verification of C1 tensor failed\n";
}
if (!verified_Mean) {
std::cerr << "Verification of Mean tensor failed\n";
}
if (!verified_Variance) {
std::cerr << "Verification of Variance tensor failed\n";
}
return false;
}
return true;
}
/// Profiles
bool profile() {
//
// Profile
//
cutlass::Status status = cutlass::Status::kSuccess;
cudaError_t result;
cudaEvent_t events[2];
int const kIterations = options.iterations;
for (cudaEvent_t &evt : events) {
result = cudaEventCreate(&evt);
if (result != cudaSuccess) {
std::cerr << "cudaEventCreate failed with error " << cudaGetErrorString(result) << std::endl;
return false;
}
}
result = cudaEventRecord(events[0]);
if (result != cudaSuccess) {
std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl;
return false;
}
for (int iter = 0; iter < kIterations; ++iter) {
status = execute_device_kernel();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Device execution failed." << std::endl;
return false;
}
}
result = cudaEventRecord(events[1]);
if (result != cudaSuccess) {
std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl;
return false;
}
result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "cudaDeviceSynchronize() failed with error " << cudaGetErrorString(result) << std::endl;
return false;
}
float elapsed_ms = 0;
result = cudaEventElapsedTime(&elapsed_ms, events[0], events[1]);
if (result != cudaSuccess) {
std::cerr << "cudaEventElapsedTime() failed with error " << cudaGetErrorString(result) << std::endl;
return false;
}
float elapsed_ms_per_iter = elapsed_ms / float(kIterations);
for (cudaEvent_t &evt : events) {
result = cudaEventDestroy(evt);
if (result != cudaSuccess) {
std::cerr << "cudaEventDestroy() failed with error " << cudaGetErrorString(result) << std::endl;
return false;
}
}
int64_t flops = int64_t(options.problem_size0.m()) * options.problem_size0.n() * options.problem_size0.k() * 2 \
+ int64_t(options.problem_size1.m()) * options.problem_size1.n() * options.problem_size1.k() * 2;
double gflops_per_second = double(flops) * kIterations / double(elapsed_ms / 1000.0f) / double(1.0e9);
std::cout << " 1st GEMM: "
<< options.problem_size0.m() << "-by-" << options.problem_size0.n() << "-by-" << options.problem_size0.k() << "\n"
<< " 2nd GEMM: "
<< options.problem_size1.m() << "-by-" << options.problem_size1.n() << "-by-" << options.problem_size1.k()
<< std::endl;
std::cout << " Runtime / iteration: " << elapsed_ms_per_iter << " ms\n" << std::endl;
std::cout << " GFLOP/s: " << gflops_per_second << std::endl;
return true;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, const char **argv) {
// Define final layout
using LayoutOutput = cutlass::layout::ColumnMajor;
// Options parsing
Options<LayoutOutput> options;
options.parse(argc, argv);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (!options.supported()) {
return 0;
}
// Run
Testbed<LayoutOutput> testbed(options);
Disposition disposition = testbed.run();
std::cout << std::endl;
switch (disposition) {
case Disposition::kPassed:
std::cout << "Passed" << std::endl;
break;
case Disposition::kIncorrect:
std::cout << "Incorrect" << std::endl;
break;
case Disposition::kNotVerified:
std::cout << "Not verified" << std::endl;
break;
}
return (disposition == Disposition::kPassed ? 0 : -1);
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,450 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief GEMM kernel to support the epilogue visitor model
for customized layernorm partial reduction epilogue fusion.
This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once
its usage has been stabilized. For now, it is included in this example to demonstrate
some basic output fusion options.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
struct GemmWithEpilogueVisitor {
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueVisitor = typename Epilogue::Visitor;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using TensorRefA = TensorRef<ElementA, LayoutA>;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using TensorRefB = TensorRef<ElementB, LayoutB>;
using ElementC = typename EpilogueVisitor::ElementOutput;
using LayoutC = typename Epilogue::Layout;
static ComplexTransform const kTransformA = Mma::kTransformA;
static ComplexTransform const kTransformB = Mma::kTransformB;
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
/// Split-K preserves splits that are 128b aligned
static int const kSplitKAlignment = const_max(
128 / sizeof_bits<ElementA>::value,
128 / sizeof_bits<ElementB>::value
);
//
// Structures
//
/// Argument structure
struct Arguments {
//
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size;
TensorRefA ref_A;
TensorRefB ref_B;
typename EpilogueVisitor::Arguments epilogue_visitor;
//
// Methods
//
Arguments():
mode(GemmUniversalMode::kGemm)
{ }
/// constructs an arguments structure
Arguments(
GemmUniversalMode mode_,
GemmCoord problem_size_,
TensorRefA ref_A_,
TensorRefB ref_B_,
typename EpilogueVisitor::Arguments epilogue_visitor_
):
mode(mode_),
problem_size(problem_size_),
ref_A(ref_A_),
ref_B(ref_B_),
epilogue_visitor(epilogue_visitor_)
{
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params {
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
GemmUniversalMode mode;
int gemm_k_size;
void * ptr_A;
void * ptr_B;
typename EpilogueVisitor::Params epilogue_visitor;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
swizzle_log_tile(0),
params_A(0),
params_B(0),
gemm_k_size(0),
mode(cutlass::gemm::GemmUniversalMode::kGemm),
ptr_A(nullptr),
ptr_B(nullptr)
{ }
Params(
Arguments const &args
):
problem_size(args.problem_size),
swizzle_log_tile(0),
params_A(args.ref_A.layout()),
params_B(args.ref_B.layout()),
mode(args.mode),
gemm_k_size(args.problem_size.k()),
ptr_A(args.ref_A.data()),
ptr_B(args.ref_B.data()),
epilogue_visitor(args.epilogue_visitor)
{
ThreadblockSwizzle threadblock_swizzle;
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, 1);
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
int const kAlignK = const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
gemm_k_size = round_up(args.problem_size.k(), kAlignK);
if (gemm_k_size) {
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
}
}
swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
}
};
/// Shared memory storage structure
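/// The mainloop and the epilogue do not use shared memory at the same time,
/// so a union lets both phases reuse a single allocation.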
union SharedStorage {
typename Mma::SharedStorage main_loop;
struct {
typename Epilogue::SharedStorage epilogue;
typename EpilogueVisitor::SharedStorage visitor;
} epilogue;
};
public:
//
// Methods
//
CUTLASS_DEVICE
GemmWithEpilogueVisitor() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(
cutlass::gemm::GemmCoord const & problem_size) {
CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()");
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
bool isAMisaligned = false;
bool isBMisaligned = false;
bool isCMisaligned = false;
if (platform::is_same<LayoutA, layout::RowMajor>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
} else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
isAMisaligned = problem_size.m() % kAlignmentA;
} else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
|| platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
}
if (platform::is_same<LayoutB, layout::RowMajor>::value) {
isBMisaligned = problem_size.n() % kAlignmentB;
} else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
} else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
|| platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
}
if (platform::is_same<LayoutC, layout::RowMajor>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
} else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
isCMisaligned = problem_size.m() % kAlignmentC;
} else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
|| platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
}
if (isAMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
return Status::kErrorMisalignedOperand;
}
if (isBMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
return Status::kErrorMisalignedOperand;
}
if (isCMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
return Status::kErrorMisalignedOperand;
}
CUTLASS_TRACE_HOST(" returning kSuccess");
return Status::kSuccess;
}
static Status can_implement(Arguments const &args) {
return can_implement(args.problem_size);
}
static size_t get_extra_workspace_size(Arguments const &args,
cutlass::gemm::GemmCoord const &grid_tiled_shape) {
return 0;
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
return;
}
int offset_k = 0;
int problem_size_k = params.problem_size.k();
ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
offset_k,
};
cutlass::MatrixCoord tb_offset_B{
offset_k,
threadblock_tile_offset.n() * Mma::Shape::kN
};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
params.params_A,
ptr_A,
{params.problem_size.m(), problem_size_k},
thread_idx,
tb_offset_A);
typename Mma::IteratorB iterator_B(
params.params_B,
ptr_B,
{problem_size_k, params.problem_size.n()},
thread_idx,
tb_offset_B);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
mma(
gemm_k_iterations,
accumulators,
iterator_A,
iterator_B,
accumulators);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
//assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.n() * Mma::Shape::kN
);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
//
// Construct the epilogue visitor
//
EpilogueVisitor epilogue_visitor(
params.epilogue_visitor,
shared_storage.epilogue.visitor,
params.problem_size.mn(),
thread_idx,
warp_idx,
lane_idx,
threadblock_offset);
if (params.mode == GemmUniversalMode::kGemm) {
// Indicate which position in a serial reduction the output operator is currently updating
epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
}
else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) {
epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
}
// Construct the epilogue
Epilogue epilogue(
shared_storage.epilogue.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Execute the epilogue operator to update the destination tensor.
epilogue(epilogue_visitor, accumulators);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

File diff suppressed because it is too large

View File

@ -0,0 +1,36 @@
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
38_syr2k_grouped
syr2k_grouped.cu
)

File diff suppressed because it is too large

View File

@ -0,0 +1,36 @@
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
39_gemm_permute
gemm_permute.cu
)

File diff suppressed because it is too large

View File

@ -0,0 +1,162 @@
# CUTLASS Python Interface Example
## Using Docker
You can run PyCUTLASS in the NGC PyTorch container.
```shell
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:22.08-py3
```
PyCUTLASS additionally requires the Boost C++ library, which can be installed with
```bash
apt-get update
apt-get -y install libboost-all-dev
```
## Install the Python Interface
The source code for the Python interface is located at `tools/library/scripts/pycutlass`. It requires two environment variables:
* `CUTLASS_PATH`: the root directory of CUTLASS
* `CUDA_INSTALL_PATH`: the directory where the CUDA Toolkit is installed
After setting these two environment variables, PyCUTLASS can be installed with
```shell
cd $CUTLASS_PATH/tools/library/scripts/pycutlass && bash build.sh
```
***
## Troubleshooting
### Issue 1: permission denied
Building PyCUTLASS installs dependencies into your Python environment, so a conda environment can be an option if you do not have the required permissions.
### Issue 2: rmm: module not found
PyCUTLASS manages device memory with [RMM](https://github.com/rapidsai/rmm). Our `build.sh` automatically pulls [rmm branch-22.08](https://github.com/rapidsai/rmm/tree/branch-22.08) from GitHub and builds it from source. RMM is located at `$CUTLASS_PATH/tools/library/scripts/pycutlass/rmm` and requires `cmake > 3.20.1`. If the build fails, it can be fixed manually with the following steps:
```shell
cd $CUTLASS_PATH/tools/library/scripts/pycutlass/rmm && ./build.sh librmm rmm
cd $CUTLASS_PATH/tools/library/scripts/pycutlass/rmm/python
python setup.py build_ext --inplace
python setup.py install
```
To test whether RMM was installed successfully, try `import rmm` in a Python interpreter. For other RMM-related issues, please check https://github.com/rapidsai/rmm/issues.
***
For all of the examples below, add `--print_cuda` to print the underlying CUDA kernel. Use `-h` or `--help` to display the help message.
## GEMM Examples
The GEMM examples use NumPy to create input tensors and verify the results; the sketch below outlines the common pattern.
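The problem size, the row-major layout, and the standalone reference computation below are illustrative assumptions; the actual scripts drive the CUTLASS kernel through `pycutlass` and compare its output against a host reference module.
```python
import numpy as np

# Illustrative problem size; gemm.py takes M, N, K from the "-p" option.
M, N, K = 512, 256, 128
alpha, beta = 1.0, 0.5

# Integer-valued operands (ceil of a uniform draw) keep the floating-point
# accumulation exact, so device and host results can be compared bit-for-bit.
tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(M * K,))).astype(np.float32)
tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(K * N,))).astype(np.float32)
tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(M * N,))).astype(np.float32)

# Host reference for D = alpha * A @ B + beta * C on flattened row-major buffers.
# gemm.py produces tensor_D with the CUTLASS kernel and then checks
# np.array_equal(tensor_D, tensor_D_ref).
tensor_D_ref = (alpha * tensor_A.reshape(M, K) @ tensor_B.reshape(K, N)
                + beta * tensor_C.reshape(M, N)).reshape(-1)
```
With integer-valued operands of this magnitude, every intermediate value is exactly representable in float32, which is why the scripts can assert exact equality rather than using a tolerance-based check.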
### GEMM F64 Example
Example 1: SM80_Device_Gemm_f64t_f64n_f64n_tensor_op_f64_32x32x16_16x16x16
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 2: SM80_Device_Gemm_f64n_f64t_f64n_tensor_op_f64_64x64x16_32x32x16, split_k(2)_serial
```python
python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2
```
### GEMM F32 Example
Example 1: SM80_Device_Gemm_f32n_f32t_f32n_tensor_op_bf16_f32_128x128x32_64x64x32
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 2: SM80_Device_Gemm_f32t_f32t_f32n_tensor_op_f32_128x128x32_64x64x32, split_k(2)_parallel
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2
```
Example 3: SM80_Device_Gemm_f32t_f32t_f32n_tensor_op_fast_accurate_f32_64x64x32_32x32x32, split_k(4)_serial
```python
python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_f32 -op TensorOp -b 64 64 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 4
```
### GEMM F16 Example
Example 1: SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32_128x128x32_64x64x32
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1
```
Example 2: SM80_Device_Gemm_f16t_f16t_f16n_tensor_op_f32_128x128x64_64x64x64, split_k(2)_serial
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2
```
Example 3: SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32_256x128x64_64x64x64, split_k(3)_serial
```python
python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 256 128 64 -s 3 -w 4 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 3
```
### GEMM BF16 Example
Example 1: Device_Gemm_bf16t_bf16t_f32n_tensor_op_f32_64x128x64_32x64x64, split_k(5)_parallel
```python
python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 5
```
### GEMM Int8 Example
Example 1: SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32_256x128x128_64x64x128
```python
python gemm.py -i 16 8 32 -ta int8 -tb int8 -tc int8 -tacc int32 -m multiply_add -op TensorOp -b 128 128 128 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 16 -lb ColumnMajor -ab 16 -lc RowMajor -ac 16 -te float32 -ep FastLinearCombinationClamp -sw IdentitySwizzle2 -p 512 512 512 -alpha 1.0 -beta 0.0 -gm Gemm -k 1
```
***
## GEMM Grouped Examples
The GEMM Grouped examples use NumPy to create input tensors and verify the results.
Example 1: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule
```python
python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 0.0 -pm Device
```
Example 2: SM80_Device_GemmGrouped_f64n_f64n_f64t_tensor_op_f64_64x64x16_32x32x16, host schedule
```python
python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle2 -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 1.0 -pm Host
```
Example 3: SM80_Device_GemmGrouped_f32n_f32n_f32n_simt_f32_128x64x8_64x32x1, device schedule
```python
python gemm_grouped.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 64 8 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
```
Example 4: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule
```python
python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle8 -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device
```
***
## Conv2d Example
The Conv2d examples use PyTorch to create input tensors and verify the results. PyTorch can be installed by following the [official website](https://pytorch.org/#:~:text=Aid%20to%20Ukraine.-,INSTALL%20PYTORCH,-Select%20your%20preferences).
### Conv2d F32 Fprop
Example 1: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 13 17 8 -krsc 24 3 3 8 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 2: SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align2
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 1.0 -beta 1.0
```
Example 3: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32
```python
python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 4 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -co fprop -st Strided -ia analytic -sm Parallel -k 3 -nhwc 1 71 80 32 -krsc 64 5 5 32 -pad 2 2 2 2 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 1.0
```
### Conv2d F32 Wgrad
Example 1: Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align1
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 1 -lb TensorNHWC -ab 1 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 8 8 1 -krsc 1 3 3 1 -pad 1 1 1 1 -stride 1 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 2: Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32
```python
python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 2 4 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```
### Conv2d F32 Dgrad
Example 1: Device_Conv2d_Dgrad_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32
```python
python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```
### Conv2d F16 Fprop
Example 1: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 2: SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_2
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 3: SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_8
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia fixed_channels -sm Serial -k 1 -nhwc 1 8 8 8 -krsc 16 3 3 8 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```
Example 4: SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_128x128_32x3_64x64x32_align4
```python
python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 56 56 12 -krsc 8 1 1 12 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0
```

View File

@ -0,0 +1,277 @@
################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
import pycutlass
from pycutlass import *
from pycutlass.conv2d_operation import *
from pycutlass.utils import reference_model
import argparse
# parse the arguments
parser = argparse.ArgumentParser(description="Launch CUTLASS convolution 2d kernels from python")
# Operation description
# math instruction description
parser.add_argument("-i", "--instruction_shape",
default=[1, 1, 1], nargs=3, type=int,
help="This option describes the size of MMA op")
parser.add_argument("-ta", "--element_a", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor A')
parser.add_argument("-tb", "--element_b", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor B')
parser.add_argument("-tc", "--element_c", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor C and output tensor D')
parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of accumulator')
parser.add_argument('-m', "--math", default="multiply_add",
type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction")
parser.add_argument('-op', "--opcode", default="simt", type=str,
choices=["Simt", 'TensorOp'],
help='This option describes whether you want to use tensor \
cores (TensorOp) or regular SIMT cores (Simt) on GPU SM')
# tile description
parser.add_argument("-b", "--threadblock_shape",
default=[128, 128, 8], nargs=3, type=int,
help="This option describes the tile size a thread block will compute")
parser.add_argument("-s", "--stages", default=4,
type=int, help="Number of pipelines you want to use")
parser.add_argument("-w", "--warp_count", default=[
4, 2, 1], nargs=3, type=int,
help="This option describes the number of warps along M, N, and K of the threadblock")
parser.add_argument("-cc", "--compute_capability", default=80,
type=int, help="This option describes CUDA SM architecture number")
# A
parser.add_argument('-la', "--layout_a", default="TensorNHWC", type=str, choices=[
"TensorNHWC", "TensorNC32HW32"],
help="Memory layout of input tensor A")
parser.add_argument('-aa', '--alignment_a', default=1,
type=int, help="Memory alignment of input tensor A")
# B
parser.add_argument('-lb', "--layout_b", default="TensorNHWC", type=str, choices=[
"TensorNHWC", "TensorC32RSK32"],
help="Memory layout of input tensor B")
parser.add_argument('-ab', '--alignment_b', default=1,
type=int, help="Memory alignment of input tensor B")
# C
parser.add_argument('-lc', "--layout_c", default="TensorNHWC", type=str, choices=[
"TensorNHWC", "TensorNC32HW32"],
help="Memory layout of input tensor C and output tensor D")
parser.add_argument('-ac', '--alignment_c', default=1,
type=int, help="Memory alignment of input tensor C and output tensor D")
# epilogue
parser.add_argument("-te", "--element_epilogue", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16'],
help='Data type of computation in the epilogue')
parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'],
help="This option describes the epilogue part of the kernel")
# swizzling
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[
"IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8",
"HorizontalSwizzle", "StridedDgradIdentitySwizzle1", "StridedDgradIdentitySwizzle4",
"StridedDgradHorizontalSwizzle"],
help="This option describes how thread blocks are scheduled on GPU")
# conv related
parser.add_argument("-co", "--conv_kind", default="fprop", type=str, choices=['fprop', 'dgrad', 'wgrad'],
help="The type of convolution: forward propagation (fprop), \
gradient of activation (dgrad), gradient of weight (wgrad)")
parser.add_argument("-st", "--stride_support", default="Strided", type=str, choices=["Strided", "Unity"],
)
parser.add_argument("-ia", "--iterator_algorithm", default="analytic", type=str,
choices=["analytic", "optimized", "fixed_channels", "few_channels"],
help="This option describes iterator algorithm")
# arguments
parser.add_argument("-sm", "--split_k_mode", default="Serial", type=str, choices=["Serial", "Parallel"],
help="Split K Mode. Serial is used for non-splitK or serial-splitK.\
Parallel is used for parallel splitK.")
parser.add_argument('-k', '--split_k_slices', default=1,
type=int, help="Number of split-k partitions. (default 1)")
parser.add_argument("-nhwc", "--nhwc", nargs=4, type=int, help="input size (NHWC)")
parser.add_argument("-krsc", "--krsc", nargs=4, type=int, help="filter size (KRSC)")
parser.add_argument("-pad", "--pad", nargs=4, type=int, help="padding (pad_h, _, pad_w, _)")
parser.add_argument("-stride", "--stride", nargs=2, type=int, help="stride (stride_h, stride_w)")
parser.add_argument("-dilation", "--dilation", nargs=2, type=int, help="dilation (dilation_h, dilation_w)")
parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha")
parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta")
parser.add_argument('--print_cuda', action="store_true",
help="print the underlying CUDA kernel")
try:
args = parser.parse_args()
except:
sys.exit(0)
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
element_acc = getattr(cutlass, args.element_acc)
math_operation = getattr(MathOperation, args.math)
opclass = getattr(cutlass.OpClass, args.opcode)
math_inst = MathInstruction(
args.instruction_shape, element_a, element_b,
element_acc, opclass, math_operation
)
tile_description = TileDescription(
args.threadblock_shape, args.stages, args.warp_count,
math_inst, args.compute_capability, args.compute_capability
)
layout_a = getattr(cutlass, args.layout_a)
layout_b = getattr(cutlass, args.layout_b)
layout_c = getattr(cutlass, args.layout_c)
A = TensorDescription(
element_a, layout_a, args.alignment_a
)
B = TensorDescription(
element_b, layout_b, args.alignment_b
)
C = TensorDescription(
element_c, layout_c, args.alignment_c
)
element_epilogue = getattr(cutlass, args.element_epilogue)
epilogue_functor = getattr(EpilogueFunctor, args.epilogue_functor)
iterator_algorithm = getattr(cutlass.conv.IteratorAlgorithm, args.iterator_algorithm)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
stride_support = getattr(StrideSupport, args.stride_support)
conv_kind = getattr(cutlass.conv.Operator, args.conv_kind)
operation = Conv2dOperation(
conv_kind=conv_kind, iterator_algorithm=iterator_algorithm,
arch=args.compute_capability, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue, stride_support=stride_support,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
if args.print_cuda:
print(operation.rt_module.emit())
operations = [operation,]
if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
reduction_operation = ReductionOperation(
shape=cutlass.MatrixCoord(4, 32 * C.alignment),
C=C, element_accumulator=element_acc,
element_compute=element_epilogue,
count=C.alignment
)
operations.append(reduction_operation)
pycutlass.compiler.add_module(operations)
problem_size = cutlass.conv.Conv2dProblemSize(
cutlass.Tensor4DCoord(args.nhwc[0], args.nhwc[1], args.nhwc[2], args.nhwc[3]),
cutlass.Tensor4DCoord(args.krsc[0], args.krsc[1], args.krsc[2], args.krsc[3]),
cutlass.Tensor4DCoord(args.pad[0], args.pad[1], args.pad[2], args.pad[3]),
cutlass.MatrixCoord(args.stride[0], args.stride[1]),
cutlass.MatrixCoord(args.dilation[0], args.dilation[1]),
cutlass.conv.Mode.cross_correlation,
args.split_k_slices, 1
)
# User-provided inputs
tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size(
conv_kind, problem_size
)
tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size(
conv_kind, problem_size
)
tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size(
conv_kind, problem_size
)
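# For floating-point element types, the ceil of a uniform draw below produces small
# integer-valued inputs, so the convolution result is exact and the script can use
# torch.equal for an exact comparison against the reference output.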
if args.element_a != "int8":
tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-8.5, 7.5))
else:
tensor_A = torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-2, 2)
if args.element_b != "int8":
tensor_B = torch.ceil(torch.empty(size=(tensor_B_size,), dtype=getattr(torch, args.element_b), device="cuda").uniform_(-8.5, 7.5))
else:
tensor_B = torch.empty(size=(tensor_B_size,), dtype=getattr(torch, args.element_b), device="cuda").uniform_(-2, 2)
if args.element_c != "int8":
tensor_C = torch.ceil(torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-8.5, 7.5))
else:
tensor_C = torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-2, 2)
tensor_D = torch.ones_like(tensor_C)
arguments = Conv2dArguments(
operation=operation, problem_size=problem_size, A=tensor_A,
B=tensor_B, C=tensor_C, D=tensor_D,
output_op = LinearCombinationFunctorArguments(args.alpha, args.beta),
split_k_mode=getattr(cutlass.conv.SplitKMode, args.split_k_mode),
split_k_slices=problem_size.split_k_slices
)
if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
implicit_gemm_size = cutlass.conv.implicit_gemm_problem_size(conv_kind, arguments.problem_size)
reduction_arguments = ReductionArguments(
reduction_operation,
problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()],
partitions=problem_size.split_k_slices,
workspace=arguments.ptr_D,
destination=tensor_D,
source=tensor_C,
output_op = LinearCombinationFunctorArguments(args.alpha, args.beta)
)
operation.run(arguments)
if args.split_k_mode == "Parallel" and args.split_k_slices > 1:
reduction_operation.run(reduction_arguments)
reduction_arguments.sync()
else:
arguments.sync()
reference_model = Conv2dReferenceModule(A, B, C, conv_kind)
tensor_D_ref = reference_model.run(tensor_A, tensor_B, tensor_C, arguments.problem_size, args.alpha, args.beta)
assert torch.equal(tensor_D, tensor_D_ref)
print("Passed.")

View File

@ -0,0 +1,266 @@
################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
import numpy as np
import pycutlass
from pycutlass import *
import cutlass
from bfloat16 import bfloat16
import argparse
# parse the arguments
parser = argparse.ArgumentParser(
description="Launch CUTLASS GEMM kernels from python: 'D = alpha * A * B + beta * C'")
# Operation description
# math instruction description
parser.add_argument("-i", "--instruction_shape",
default=[1, 1, 1], nargs=3, type=int,
help="This option describes the size of MMA op")
parser.add_argument("-ta", "--element_a", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor A')
parser.add_argument("-tb", "--element_b", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor B')
parser.add_argument("-tc", "--element_c", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor C and output tensor D')
parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of accumulator')
parser.add_argument('-m', "--math", default="multiply_add",
type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction")
parser.add_argument('-op', "--opcode", default="simt", type=str,
choices=["Simt", 'TensorOp'],
help="This option describes whether you want to use tensor \
cores (TensorOp) or regular SIMT cores (Simt) on GPU SM")
# tile description
parser.add_argument("-b", "--threadblock_shape",
default=[128, 128, 8], nargs=3, type=int,
help="This option describes the tile size a thread block will compute")
parser.add_argument("-s", "--stages", default=4,
type=int, help="Number of pipelines you want to use")
parser.add_argument("-w", "--warp_count", default=[4, 2, 1], nargs=3, type=int,
help="This option describes the number of warps along M, N, and K of the threadblock")
parser.add_argument("-cc", "--compute_capability", default=80,
type=int, help="This option describes CUDA SM architecture number")
# A
parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
help="Memory layout of input tensor A")
parser.add_argument('-aa', '--alignment_a', default=1,
type=int, help="Memory alignment of input tensor A")
# B
parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
help="Memory layout of input tensor B")
parser.add_argument('-ab', '--alignment_b', default=1,
type=int, help="Memory alignment of input tensor B")
# C
parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str, choices=[
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
help="Memory layout of input tensor C and output tensor D")
parser.add_argument('-ac', '--alignment_c', default=1,
type=int, help="Memory alignment of input tensor C and output tensor D")
# epilogue
parser.add_argument("-te", "--element_epilogue", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16'], help='Epilogue datatype')
parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'],
help="This option describes the epilogue part of the kernel")
# swizzling
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[
"IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle"],
help="This option describes how thread blocks are scheduled on GPU")
# Argument
parser.add_argument("-p", "--problem_size",
default=[128, 128, 128], nargs=3, type=int,
help="GEMM problem size M, N, K")
parser.add_argument("-alpha", "--alpha", default=1.0, type=float,
help="Scaling factor of A * B")
parser.add_argument("-beta", "--beta", default=0.0, type=float,
help="Scaling factor of C")
parser.add_argument("-gm", "--gemm_mode", default="Gemm", type=str,
choices=["Gemm", "GemmSplitKParallel"],
help="GEMM mode. Gemm is used for non-splitK or serial-splitK. \
GemmSplitKParallel is used for parallel splitK")
parser.add_argument('-k', '--split_k_slices', default=1,
type=int, help="Number of split-k partitions. (default 1)")
parser.add_argument('--print_cuda', action="store_true",
help="print the underlying CUDA kernel")
# parser.add_argument('-h', '--help', action="store_true",
# help="print help information")
try:
args = parser.parse_args()
except:
sys.exit(0)
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
element_acc = getattr(cutlass, args.element_acc)
math_operation = getattr(MathOperation, args.math)
opclass = getattr(cutlass.OpClass, args.opcode)
math_inst = MathInstruction(
args.instruction_shape, element_a, element_b,
element_acc, opclass, math_operation
)
tile_description = TileDescription(
args.threadblock_shape, args.stages, args.warp_count,
math_inst, args.compute_capability, args.compute_capability
)
layout_a = getattr(cutlass, args.layout_a)
layout_b = getattr(cutlass, args.layout_b)
layout_c = getattr(cutlass, args.layout_c)
A = TensorDescription(
element_a, layout_a, args.alignment_a
)
B = TensorDescription(
element_b, layout_b, args.alignment_b
)
C = TensorDescription(
element_c, layout_c, args.alignment_c
)
element_epilogue = getattr(cutlass, args.element_epilogue)
epilogue_functor = getattr(EpilogueFunctor, args.epilogue_functor)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
operation = GemmOperationUniversal(
arch=args.compute_capability, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor
)
if args.print_cuda:
print(operation.rt_module.emit())
operations = [operation, ]
if args.gemm_mode == "GemmSplitKParallel":
reduction_operation = ReductionOperation(
shape=cutlass.MatrixCoord(4, 32 * C.alignment),
C=C, element_accumulator=element_acc,
element_compute=element_epilogue,
count=C.alignment
)
operations.append(reduction_operation)
pycutlass.compiler.add_module(operations)
# User-provided inputs
problem_size = cutlass.gemm.GemmCoord(
args.problem_size[0], args.problem_size[1], args.problem_size[2])
if args.element_a != "int8":
if args.element_a == "bfloat16":
tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.k(),))).astype(bfloat16)
else:
tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.k(),))).astype(getattr(np, args.element_a))
else:
tensor_A = np.random.uniform(low=-2, high=2, size=(problem_size.m()
* problem_size.k(),)).astype(getattr(np, args.element_a))
if args.element_b != "int8":
if args.element_b == "bfloat16":
tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k()
* problem_size.n(),))).astype(bfloat16)
else:
tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k()
* problem_size.n(),))).astype(getattr(np, args.element_b))
else:
tensor_B = np.random.uniform(low=-2, high=2, size=(problem_size.k()
* problem_size.n(),)).astype(getattr(np, args.element_b))
if args.element_c != "int8":
if args.element_c == "bfloat16":
tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.n(),))).astype(bfloat16)
else:
tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.n(),))).astype(getattr(np, args.element_c))
else:
tensor_C = np.random.uniform(low=-2, high=2, size=(problem_size.m()
* problem_size.n(),)).astype(getattr(np, args.element_c))
tensor_D = np.ones_like(tensor_C)
arguments = GemmArguments(
operation=operation, problem_size=problem_size,
A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
output_op=LinearCombinationFunctorArguments(args.alpha, args.beta),
gemm_mode=getattr(cutlass.gemm.Mode, args.gemm_mode),
split_k_slices=args.split_k_slices
)
if args.gemm_mode == "GemmSplitKParallel":
reduction_arguments = ReductionArguments(
operation=reduction_operation,
problem_size=[problem_size.m(), problem_size.n()],
partitions=args.split_k_slices, workspace=arguments.ptr_D,
destination=tensor_D, source=tensor_C,
output_op=LinearCombinationFunctorArguments(args.alpha, args.beta)
)
operation.run(arguments)
if args.gemm_mode == "GemmSplitKParallel":
reduction_operation.run(reduction_arguments)
reduction_arguments.sync()
else:
arguments.sync()
# run the host reference module
reference = ReferenceModule(A, B, C)
tensor_D_ref = reference.run(
tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta)
assert np.array_equal(tensor_D, tensor_D_ref)
print("Passed.")

View File

@ -0,0 +1,248 @@
################################################################################
#
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
import pycutlass
from pycutlass import *
import csv
import argparse
import sys  # used by the argument-parsing guard below (sys.exit)
# parse the arguments
parser = argparse.ArgumentParser(
description="Launch CUTLASS GEMM Grouped kernels from python")
# Operation description
# math instruction description
parser.add_argument("-i", "--instruction_shape",
default=[1, 1, 1], nargs=3, type=int,
help="This option describes the size of MMA op")
parser.add_argument("-ta", "--element_a", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor A')
parser.add_argument("-tb", "--element_b", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor B')
parser.add_argument("-tc", "--element_c", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of elements in input tensor C and output tensor D')
parser.add_argument("-tacc", "--element_acc", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'],
help='Data type of accumulator')
parser.add_argument('-m', "--math", default="multiply_add",
type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction")
parser.add_argument('-op', "--opcode", default="Simt", type=str,
choices=["Simt", 'TensorOp'], help='This option describes whether you want to use tensor \
cores (TensorOp) or regular SIMT cores (Simt) on GPU SM')
# tile description
parser.add_argument("-b", "--threadblock_shape",
default=[128, 128, 8], nargs=3, type=int,
help="This option describes the tile size a thread block with compute")
parser.add_argument("-s", "--stages", default=4,
type=int, help="Number of pipelines you want to use")
parser.add_argument("-w", "--warp_count", default=[
4, 2, 1], nargs=3, type=int,
help="This option describes the number of warps along M, N, and K of the threadblock")
parser.add_argument("-cc", "--compute_capability", default=80,
type=int, help="This option describes CUDA SM architecture number")
# A
parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
help="Memory layout of input tensor A")
parser.add_argument('-aa', '--alignment_a', default=1,
type=int, help="Memory alignment of input tensor A")
# B
parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
help="Memory layout of input tensor B")
parser.add_argument('-ab', '--alignment_b', default=1,
type=int, help="Memory alignment of input tensor B")
# C
parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str, choices=[
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"],
help="Memory layout of input tensor C and output tensor D")
parser.add_argument('-ac', '--alignment_c', default=1,
type=int, help="Memory alignment of input tensor C and output tensor D")
# epilogue
parser.add_argument("-te", "--element_epilogue", default="float32", type=str,
choices=['float64', 'float32', 'float16', 'bfloat16'], help='Epilogue datatype')
parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination",
type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'],
help="This option describes the epilogue part of the kernel")
# swizzling
parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[
"IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle"],
help="This option describes how thread blocks are scheduled on GPU")
# precompute mode
parser.add_argument("-pm", "--precompute_mode",
default="Device", type=str, choices=["Host", "Device"],
help="Grouped Gemm Scheduing on device only (Device) or using host precompute (Host)")
# arguments
parser.add_argument("-p", "--problem_size_dir", type=str,
help="path to the csv file contains the problem sizes")
parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha")
parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta")
parser.add_argument('--print_cuda', action="store_true",
help="print the underlying CUDA kernel")
try:
args = parser.parse_args()
except:
sys.exit(0)
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
element_a = getattr(cutlass, args.element_a)
element_b = getattr(cutlass, args.element_b)
element_c = getattr(cutlass, args.element_c)
element_acc = getattr(cutlass, args.element_acc)
math_operation = getattr(MathOperation, args.math)
opclass = getattr(cutlass.OpClass, args.opcode)
math_inst = MathInstruction(
args.instruction_shape, element_a, element_b,
element_acc, opclass, math_operation
)
tile_description = TileDescription(
args.threadblock_shape, args.stages, args.warp_count,
math_inst, args.compute_capability, args.compute_capability
)
layout_a = getattr(cutlass, args.layout_a)
layout_b = getattr(cutlass, args.layout_b)
layout_c = getattr(cutlass, args.layout_c)
A = TensorDescription(
element_a, layout_a, args.alignment_a
)
B = TensorDescription(
element_b, layout_b, args.alignment_b
)
C = TensorDescription(
element_c, layout_c, args.alignment_c
)
element_epilogue = getattr(cutlass, args.element_epilogue)
epilogue_functor = getattr(EpilogueFunctor, args.epilogue_functor)
swizzling_functor = getattr(cutlass, args.swizzling_functor)
precompute_mode = getattr(SchedulerMode, args.precompute_mode)
operation = GemmOperationGrouped(
arch=args.compute_capability, tile_description=tile_description,
A=A, B=B, C=C, element_epilogue=element_epilogue,
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor,
precompute_mode=precompute_mode
)
if args.print_cuda:
print(operation.rt_module.emit())
pycutlass.compiler.add_module([operation, ])
reference_module = ReferenceModule(A, B, C)
# get problems
problem_sizes = []
with open(args.problem_size_dir) as csv_file:
reader = csv.reader(csv_file)
for row in reader:
problem_sizes.append(
cutlass.gemm.GemmCoord(int(row[0]), int(row[1]), int(row[2]))
)
problem_count = len(problem_sizes)
tensor_As = []
tensor_Bs = []
tensor_Cs = []
tensor_Ds = []
problem_sizes_coord = []
tensor_D_refs = []
for problem_size in problem_sizes:
if args.element_a != "int8":
if args.element_a == "bfloat16":
tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.k(),))).astype(bfloat16)
else:
tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.k(),))).astype(getattr(np, args.element_a))
else:
tensor_A = np.random.uniform(low=-2, high=2, size=(problem_size.m()
* problem_size.k(),)).astype(getattr(np, args.element_a))
if args.element_b != "int8":
if args.element_b == "bfloat16":
tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k()
* problem_size.n(),))).astype(bfloat16)
else:
tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k()
* problem_size.n(),))).astype(getattr(np, args.element_b))
else:
tensor_B = np.random.uniform(low=-2, high=2, size=(problem_size.k()
* problem_size.n(),)).astype(getattr(np, args.element_b))
if args.element_c != "int8":
if args.element_c == "bfloat16":
tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.n(),))).astype(bfloat16)
else:
tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m()
* problem_size.n(),))).astype(getattr(np, args.element_c))
else:
tensor_C = np.random.uniform(low=-2, high=2, size=(problem_size.m()
* problem_size.n(),)).astype(getattr(np, args.element_c))
tensor_D = np.zeros_like(tensor_C)
tensor_As.append(tensor_A)
tensor_Bs.append(tensor_B)
tensor_Cs.append(tensor_C)
tensor_Ds.append(tensor_D)
tensor_D_refs.append(reference_module.run(
tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta))
problem_sizes_coord.append(problem_size)
arguments = GemmGroupedArguments(
operation, problem_sizes_coord, tensor_As, tensor_Bs, tensor_Cs, tensor_Ds,
output_op=LinearCombinationFunctorArguments(args.alpha, args.beta)
)
operation.run(arguments)
arguments.sync()
for tensor_d, tensor_d_ref in zip(tensor_Ds, tensor_D_refs):
assert np.array_equal(tensor_d, tensor_d_ref)
print("Passed.")

View File

@ -0,0 +1,3 @@
128,128,128
128,128,256
512,128,384

View File

@ -1,169 +0,0 @@
# System modules
import numpy as np
import os.path
import sys
import ctypes
# CUDA Python modules
from cuda import cuda
from cuda import nvrtc
# CUTLASS modules
import library
import manifest as cutlass_manifest
import generator
import rt
#
# Construct an SGEMM
#
manifest = cutlass_manifest.Manifest()
generator.GenerateSM50_Simt(manifest, "11.5.0")
#
# Construct a GEMM operation
#
operation = manifest.operations_by_name['cutlass_simt_sgemm_128x128_8x2_nt_align1']
#
# Construct a runtime GEMM operation
#
gemm = rt.Gemm(operation)
#
# Initialize context
#
err, = cuda.cuInit(0)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
err, device = cuda.cuDeviceGet(0)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
err, context = cuda.cuCtxCreate(0, device)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
#
# Construct a module
#
architectures = [80,]
include_paths = [
'../../include',
'../../tools/util/include',
]
compilation_options = rt.CompilationOptions(architectures, include_paths)
module = rt.Module('module.cu', [gemm], compilation_options)
#
# Setup a workspace
#
M, N, K = (128, 128, 128)
tensor_A = np.ndarray(M * K, dtype=np.float32)
tensor_B = np.ndarray(N * K, dtype=np.float32)
tensor_C = np.ndarray(M * N, dtype=np.float32)
tensor_D = np.ndarray(M * N, dtype=np.float32)
err, tensor_A_d = cuda.cuMemAlloc(tensor_A.size * tensor_A.itemsize)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
err, tensor_B_d = cuda.cuMemAlloc(tensor_B.size * tensor_B.itemsize)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
err, tensor_C_d = cuda.cuMemAlloc(tensor_C.size * tensor_C.itemsize)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
err, tensor_D_d = cuda.cuMemAlloc(tensor_D.size * tensor_D.itemsize)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
err, stream = cuda.cuStreamCreate(0)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
tensors = [
(tensor_A_d, tensor_A),
(tensor_B_d, tensor_B),
(tensor_C_d, tensor_C),
(tensor_D_d, tensor_D)
]
for tensor_device, tensor_host in tensors:
bytes = tensor_host.size * tensor_host.itemsize
print("Tensor has dimensions: %s (%d bytes)" % (str(tensor_host.size), tensor_host.itemsize))
err, = cuda.cuMemcpyHtoDAsync(tensor_device, tensor_host, bytes, stream)
print("updating tensor in device memory ", hex(int(tensor_device)))
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError('CUDA Error %s' % str(err))
#
# Initialize a host buffer
#
arguments = rt.GemmArguments()
arguments.problem_size = rt.GemmCoord(M, N, K)
arguments.A = rt.TensorRef(tensor_A_d, M)
arguments.B = rt.TensorRef(tensor_B_d, N)
arguments.C = rt.TensorRef(tensor_C_d, M)
arguments.D = rt.TensorRef(tensor_D_d, M)
host_workspace = bytearray(gemm.get_host_workspace_size(arguments))
device_workspace = None
launch_config = gemm.plan(arguments)
byte_count = gemm.initialize(host_workspace, device_workspace, launch_config, arguments)
#
# Launch the kernel
#
err = gemm.run(host_workspace, device_workspace, launch_config)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError('CUDA Error %s' % str(err))
#
# Verify results
#
err, = cuda.cuStreamSynchronize(stream)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("CUDA Error %s" % str(err))
#
# Debug reporting of byte array contents
#
def PrintBytearray(host_workspace):
uint_str = None
prefix = None
print("uint32_t host_workspace[] = {")
for idx, byte in enumerate(host_workspace):
if not (idx % 4):
if uint_str is not None:
print(prefix, uint_str, ",")
prefix = "/* offset: %d B */ 0x" % idx
uint_str = ""
uint_str = "{:02x}".format(byte) + uint_str
print("};")

View File

@ -0,0 +1,36 @@
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
41_multi_head_attention
fused_multihead_attention.cu
)

File diff suppressed because it is too large

View File

@ -0,0 +1,626 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines the FusedMultiHeadAttention Class
The class contains the following:
1) GEMM0 with epilogue fusion,
2) GEMM1 with mainloop fusion, and
3) A lightweight full softmax reduction kernel.
*/
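// Sketch of the computation the stages above implement, as read from the Arguments
// below (not authoritative): for each (batch, head) problem,
//   P = softmax(alpha0 * Q * K^T)  -- GEMM0; the epilogue visitor records per-row max / exp-sum
//   O = P * V                      -- GEMM1; the mainloop fusion consumes the normalization
// with ApplyFinalReductionDevice combining the per-tile partial max/sum values between the two.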
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include <cmath>
#include <iostream>
#include <vector>
#include <limits>
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/gemm/kernel/default_gemm_grouped_softmax_mainloop_fusion.h"
#include "cutlass/reduction/kernel/reduce_softmax_final.h"
#include "gemm_grouped_with_softmax_visitor.h"
namespace cutlass {
template <
typename ElementQ_,
typename LayoutQ_,
typename ElementK_,
typename LayoutK_,
typename ElementP_,
typename LayoutP_,
typename ElementCompute_,
typename OperatorClass_,
typename ArchTag_,
typename ThreadblockShape0_,
typename ThreadblockShape1_,
typename WarpShape0_,
typename WarpShape1_,
typename InstructionShape_,
int kStages0_,
int kStages1_,
bool UseMasking_ = false,
cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode0_ = cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute,
cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode1_ = cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute,
int Alignment = 128 / cutlass::sizeof_bits<ElementQ_>::value,
typename ElementSoftmax_ = ElementP_
>
class FusedMultiHeadAttention {
public:
using ElementQ = ElementQ_;
using ElementK = ElementK_;
using ElementP = ElementP_;
using ElementV = ElementK;
using ElementOutput = ElementP;
using ElementAccumulator = ElementCompute_;
using LayoutQ = LayoutQ_;
using LayoutK = LayoutK_;
using LayoutP = LayoutP_;
using LayoutV = LayoutK;
using LayoutO = LayoutP;
using ElementNorm = cutlass::half_t;
using ElementSum = cutlass::half_t;
using ElementSoftmaxCompute = float;
using LayoutNorm = cutlass::layout::RowMajor;
using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
using ThreadblockShape0 = ThreadblockShape0_;
using WarpShape0 = WarpShape0_;
using ThreadblockShape1 = ThreadblockShape1_;
using WarpShape1 = WarpShape1_;
static int const Stages0 = kStages0_;
static int const Stages1 = kStages1_;
using InstructionShape = InstructionShape_;
using EpilogueOutputOp0 = cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;
using EpilogueOutputOp1 = cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator, cutlass::epilogue::thread::ScaleType::Nothing>;
using Operator = typename cutlass::gemm::device::DefaultGemmConfiguration<
OperatorClass, ArchTag, ElementQ, ElementK, ElementP,
ElementAccumulator>::Operator;
static bool const kInternalTranspose = cutlass::platform::is_same<LayoutP, cutlass::layout::ColumnMajor>::value;
static bool const kUseMasking = UseMasking_;
static cutlass::gemm::kernel::GroupScheduleMode const kGroupScheduleMode0 = GroupScheduleMode0_;
static cutlass::gemm::kernel::GroupScheduleMode const kGroupScheduleMode1 = GroupScheduleMode1_;
using MapArguments = cutlass::gemm::kernel::detail::MapArguments<
ElementQ,
LayoutQ,
cutlass::ComplexTransform::kNone,
8,
ElementK,
LayoutK,
cutlass::ComplexTransform::kNone,
8,
LayoutP,
kInternalTranspose
>;
using DefaultGemmKernel = typename cutlass::gemm::kernel::DefaultGemm<
typename MapArguments::ElementA,
typename MapArguments::LayoutA,
MapArguments::kAlignmentA,
typename MapArguments::ElementB,
typename MapArguments::LayoutB,
MapArguments::kAlignmentB,
ElementP,
typename MapArguments::LayoutC,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape0,
WarpShape0,
InstructionShape,
EpilogueOutputOp0,
ThreadblockSwizzle,
Stages0,
true,
Operator,
cutlass::gemm::SharedMemoryClearOption::kNone
>::GemmKernel;
using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorSoftmax<
ThreadblockShape0,
DefaultGemmKernel::kThreadCount,
typename DefaultGemmKernel::Epilogue::OutputTileIterator,
typename EpilogueOutputOp0::ElementCompute,
ElementNorm,
ElementSum,
ElementSoftmaxCompute,
EpilogueOutputOp0,
kUseMasking
>;
using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue<
EpilogueVisitor,
typename DefaultGemmKernel::Epilogue
>::Epilogue;
using GemmKernel0 = cutlass::gemm::kernel::GemmGroupedWithEpilogueVistor<
typename DefaultGemmKernel::Mma,
Epilogue,
ThreadblockSwizzle,
kGroupScheduleMode0,
kInternalTranspose,
kUseMasking
>;
using GemmGrouped0 = cutlass::gemm::device::GemmGrouped<GemmKernel0>;
using ApplyFinalReductionDevice = cutlass::reduction::kernel::ApplySoftmaxFinalReduction<
ElementNorm,
ElementSum,
typename GemmGrouped0::GemmKernel::EpilogueVisitor::ElementSoftmaxCompute,
typename GemmGrouped0::GemmKernel::EpilogueVisitor::ThreadblockShape,
true
>;
using GemmKernel1 = typename cutlass::gemm::kernel::DefaultGemmGroupedSoftmaxMainloopFusion<
ElementP,
LayoutP,
cutlass::ComplexTransform::kNone,
128 / cutlass::sizeof_bits<ElementQ>::value,
ElementV,
LayoutV,
cutlass::ComplexTransform::kNone,
128 / cutlass::sizeof_bits<ElementK>::value,
ElementNorm,
LayoutNorm,
ElementOutput,
LayoutO,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape1,
WarpShape1,
InstructionShape,
EpilogueOutputOp1,
ThreadblockSwizzle,
Stages1,
kGroupScheduleMode1
>::GemmKernel;
using GemmGrouped1 = cutlass::gemm::device::GemmGrouped<GemmKernel1>;
public:
/// Arguments class
struct Arguments {
cutlass::gemm::GemmCoord *problem_sizes0;
cutlass::gemm::GemmCoord *problem_sizes0_real;
cutlass::gemm::GemmCoord *problem_sizes1;
int problem_count;
int threadblock_count;
ElementQ ** ptr_Q;
ElementK ** ptr_K;
ElementP ** ptr_P;
ElementP ** ptr_V;
ElementP ** ptr_O;
ElementNorm **ptr_Max;
ElementSum **ptr_Sum;
ElementP *block_P;
ElementNorm *block_Norm;
ElementSum *block_Sum;
int64_t *offset_P;
int64_t *offset_Norm_Device;
int64_t *offset_Sum_Device;
typename LayoutQ::Stride::LongIndex *ldq;
typename LayoutK::Stride::LongIndex *ldk;
typename LayoutP::Stride::LongIndex *ldp;
typename LayoutP::Stride::LongIndex *ldv;
typename LayoutP::Stride::LongIndex *ldo;
cutlass::gemm::GemmCoord *problem_sizes0_host;
cutlass::gemm::GemmCoord *problem_sizes1_host;
ElementAccumulator alpha0;
ElementAccumulator alpha1;
ElementAccumulator beta;
int head_number;
int batch_size;
int seq_length;
typename ApplyFinalReductionDevice::Arguments reduction;
//
// Methods
//
Arguments():
problem_count(0),
threadblock_count(0),
ptr_Q(nullptr),
ptr_K(nullptr),
ptr_P(nullptr),
ptr_V(nullptr),
ptr_O(nullptr),
ptr_Max(nullptr),
ptr_Sum(nullptr),
block_P(nullptr),
block_Norm(nullptr),
block_Sum(nullptr),
offset_P(nullptr),
offset_Norm_Device(nullptr),
offset_Sum_Device(nullptr),
ldq(nullptr),
ldk(nullptr),
ldp(nullptr),
ldv(nullptr),
ldo(nullptr),
head_number(0),
batch_size(0),
seq_length(0)
{
}
Arguments(
cutlass::gemm::GemmCoord *problem_sizes0,
cutlass::gemm::GemmCoord *problem_sizes1,
int problem_count,
int threadblock_count,
ElementQ ** ptr_Q,
ElementK ** ptr_K,
ElementP ** ptr_P,
ElementP ** ptr_V,
ElementP ** ptr_O,
ElementNorm **ptr_Max,
ElementSum **ptr_Sum,
ElementP *block_P,
ElementNorm *block_Norm,
ElementSum *block_Sum,
int64_t *offset_P,
int64_t *offset_Norm_Device,
int64_t *offset_Sum_Device,
typename LayoutQ::Stride::LongIndex *ldq,
typename LayoutK::Stride::LongIndex *ldk,
typename LayoutP::Stride::LongIndex *ldp,
typename LayoutP::Stride::LongIndex *ldv,
typename LayoutP::Stride::LongIndex *ldo,
ElementAccumulator alpha0,
ElementAccumulator alpha1,
ElementAccumulator beta,
int head_number,
int batch_size,
int seq_length,
cutlass::gemm::GemmCoord *problem_sizes0_host = nullptr,
cutlass::gemm::GemmCoord *problem_sizes1_host = nullptr,
cutlass::gemm::GemmCoord *problem_sizes0_real = nullptr
):
problem_sizes0(problem_sizes0),
problem_sizes1(problem_sizes1),
problem_count(problem_count),
threadblock_count(threadblock_count),
ptr_Q(ptr_Q),
ptr_K(ptr_K),
ptr_P(ptr_P),
ptr_V(ptr_V),
ptr_O(ptr_O),
ptr_Max(ptr_Max),
ptr_Sum(ptr_Sum),
block_P(block_P),
block_Norm(block_Norm),
block_Sum(block_Sum),
offset_P(offset_P),
offset_Norm_Device(offset_Norm_Device),
offset_Sum_Device(offset_Sum_Device),
ldq(ldq),
ldk(ldk),
ldp(ldp),
ldv(ldv),
ldo(ldo),
alpha0(alpha0),
alpha1(alpha1),
beta(beta),
head_number(head_number),
batch_size(batch_size),
seq_length(seq_length),
problem_sizes0_host(problem_sizes0_host),
problem_sizes1_host(problem_sizes1_host),
problem_sizes0_real(problem_sizes0_real),
reduction(
problem_sizes0,
block_Norm,
block_Sum,
offset_Norm_Device,
offset_Sum_Device
)
{
}
};
struct Params {
cutlass::gemm::GemmCoord *problem_sizes0;
cutlass::gemm::GemmCoord *problem_sizes0_real;
cutlass::gemm::GemmCoord *problem_sizes1;
int problem_count;
int threadblock_count;
ElementQ ** ptr_Q;
ElementK ** ptr_K;
ElementP ** ptr_P;
ElementP ** ptr_V;
ElementP ** ptr_O;
ElementNorm **ptr_Max;
ElementSum **ptr_Sum;
ElementP *block_P;
ElementNorm *block_Norm;
ElementSum *block_Sum;
int64_t *offset_P;
int64_t *offset_Norm_Device;
int64_t *offset_Sum_Device;
typename LayoutQ::Stride::LongIndex *ldq;
typename LayoutK::Stride::LongIndex *ldk;
typename LayoutP::Stride::LongIndex *ldp;
typename LayoutP::Stride::LongIndex *ldv;
typename LayoutP::Stride::LongIndex *ldo;
cutlass::gemm::GemmCoord *problem_sizes0_host;
cutlass::gemm::GemmCoord *problem_sizes1_host;
ElementAccumulator alpha0;
ElementAccumulator alpha1;
ElementAccumulator beta;
int head_number;
int batch_size;
int seq_length;
typename ApplyFinalReductionDevice::Params reduction;
Params():
problem_count(0),
threadblock_count(0),
ptr_Q(nullptr),
ptr_K(nullptr),
ptr_P(nullptr),
ptr_V(nullptr),
ptr_O(nullptr),
ptr_Max(nullptr),
ptr_Sum(nullptr),
block_P(nullptr),
block_Norm(nullptr),
block_Sum(nullptr),
offset_P(nullptr),
offset_Norm_Device(nullptr),
offset_Sum_Device(nullptr),
ldq(nullptr),
ldk(nullptr),
ldp(nullptr),
ldv(nullptr),
ldo(nullptr),
problem_sizes0(nullptr),
problem_sizes1(nullptr),
problem_sizes0_real(nullptr),
head_number(0),
batch_size(0),
seq_length(0)
{
}
Params(Arguments const &args, void *workspace = nullptr):
problem_sizes0(args.problem_sizes0),
problem_sizes1(args.problem_sizes1),
problem_count(args.problem_count),
threadblock_count(args.threadblock_count),
ptr_Q(args.ptr_Q),
ptr_K(args.ptr_K),
ptr_P(args.ptr_P),
ptr_V(args.ptr_V),
ptr_O(args.ptr_O),
ptr_Max(args.ptr_Max),
ptr_Sum(args.ptr_Sum),
block_P(args.block_P),
block_Norm(args.block_Norm),
block_Sum(args.block_Sum),
offset_P(args.offset_P),
offset_Norm_Device(args.offset_Norm_Device),
offset_Sum_Device(args.offset_Sum_Device),
ldq(args.ldq),
ldk(args.ldk),
ldp(args.ldp),
ldv(args.ldv),
ldo(args.ldo),
problem_sizes0_host(args.problem_sizes0_host),
problem_sizes1_host(args.problem_sizes1_host),
problem_sizes0_real(args.problem_sizes0_real),
alpha0(args.alpha0),
alpha1(args.alpha1),
beta(args.beta),
head_number(args.head_number),
batch_size(args.batch_size),
seq_length(args.seq_length),
reduction(args.reduction)
{
}
};
private:
Params params_;
GemmGrouped0 gemm_grouped0;
GemmGrouped1 gemm_grouped1;
public:
/// Ctor
FusedMultiHeadAttention() {
}
/// Initialize
Status initialize(Arguments const &args,
void *workspace0 = nullptr,
void *workspace1 = nullptr) {
params_ = Params(args);
typename GemmGrouped0::Arguments args_gemm0(
params_.problem_sizes0,
params_.problem_count,
params_.threadblock_count,
params_.ptr_Q,
params_.ptr_K,
params_.ptr_P,
params_.ptr_P,
params_.ptr_Max,
params_.ptr_Sum,
params_.ldq,
params_.ldk,
params_.ldp,
params_.ldp,
typename GemmGrouped0::GemmKernel::EpilogueVisitor::Arguments(
{
params_.alpha0,
params_.beta
}
),
params_.problem_sizes0_host,
params_.problem_sizes0_real
);
Status result0 = gemm_grouped0.initialize(args_gemm0, workspace0);
typename EpilogueOutputOp1::Params epilogue_op1(params_.alpha1, params_.beta);
typename GemmGrouped1::Arguments args_gemm1(
params_.problem_sizes1,
params_.problem_count,
params_.threadblock_count,
epilogue_op1,
params_.ptr_P,
params_.ptr_V,
params_.ptr_O,
params_.ptr_O,
(void**)params_.ptr_Max,
(void**)params_.ptr_Sum,
params_.ldp,
params_.ldv,
params_.ldo,
params_.ldo,
params_.problem_sizes1_host
);
Status result1 = gemm_grouped1.initialize(args_gemm1, workspace1);
if ((result0 == cutlass::Status::kSuccess) && (result1 == cutlass::Status::kSuccess) ) {
return cutlass::Status::kSuccess;
}else{
if (result0 != cutlass::Status::kSuccess) {
return result0;
}else{
return result1;
}
}
}
/// Run
Status run(cudaStream_t stream = nullptr) {
Status result = gemm_grouped0.run();
cudaError_t error_info;
if (result != cutlass::Status::kSuccess) {
return cutlass::Status::kErrorInternal;
}
int thread_per_block = 1024;
dim3 final_reduction_grid(params_.head_number * params_.batch_size);
dim3 final_reduction_block(thread_per_block);
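// One block per (head, batch) problem, as implied by the grid dimensions above; each block
// is assumed to fold that problem's per-tile partial max / exp-sum values into the final
// softmax normalization via ApplyFinalReductionDevice.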
cutlass::Kernel<ApplyFinalReductionDevice><<<
final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionDevice::SharedStorage), stream
>>>(params_.reduction);
error_info = cudaGetLastError();
if (error_info != cudaSuccess) {
return cutlass::Status::kErrorInternal;
}
result = gemm_grouped1.run();
if (result != cutlass::Status::kSuccess) {
return cutlass::Status::kErrorInternal;
}
return cutlass::Status::kSuccess;
}
/// Function call operator
Status operator()(cudaStream_t stream = nullptr) {
return run(stream);
}
};
}

View File

@ -0,0 +1,522 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Grouped GEMM kernel with epilogue visitor customized for softmax
*/
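// Outline (a hedged reading of the types below): as each threadblock streams its GEMM
// output tile through the epilogue, the EpilogueVisitor also records that tile's per-row
// maximum (ptr_Max) and sum of exponentials (ptr_Sum); a separate reduction kernel later
// folds these partials into the final softmax normalization.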
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
GroupScheduleMode GroupScheduleMode_, ///! Type of scheduling to perform
bool Transposed_ = false,
bool UseMask_ = false
>
struct GemmGroupedWithEpilogueVistor {
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
using EpilogueVisitor = typename Epilogue::Visitor;
using EpilogueOutputOp = typename EpilogueVisitor::ElementwiseFunctor;
static bool const kTransposed = Transposed_;
// Optional transpose
using MapArguments = kernel::detail::MapArguments<
typename Mma::IteratorA::Element,
typename Mma::IteratorA::Layout,
Mma::kTransformA,
Mma::IteratorA::AccessType::kElements,
typename Mma::IteratorB::Element,
typename Mma::IteratorB::Layout,
Mma::kTransformB,
Mma::IteratorB::AccessType::kElements,
typename Mma::LayoutC,
kTransposed
>;
// Public-facing type definitions related to operand element type, layout, and complex conjugate
// operation. Must interact with the 'kTransposed' notion.
using ElementA = typename MapArguments::ElementA;
using LayoutA = typename MapArguments::LayoutA;
using ElementB = typename MapArguments::ElementB;
using LayoutB = typename MapArguments::LayoutB;
using ElementC = typename EpilogueVisitor::ElementOutput;
using LayoutC = typename MapArguments::LayoutC;
using ElementNorm = typename EpilogueVisitor::ElementNorm;
using ElementSum = typename EpilogueVisitor::ElementSum;
static ComplexTransform const kTransformA = MapArguments::kTransformA;
static ComplexTransform const kTransformB = MapArguments::kTransformB;
// Type definitions about the mainloop.
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = MapArguments::kAlignmentA;
static int const kAlignmentB = MapArguments::kAlignmentB;
static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
using ProblemVisitor = GemmGroupedProblemVisitor<
ThreadblockShape,
kGroupScheduleMode,
kThreadCount,
kThreadCount,
kTransposed>;
//
// Structures
//
/// Argument structure
struct Arguments {
//
// Data members
//
GemmCoord *problem_sizes;
// When masking is used, the real problem sizes may not be tile-aligned,
// so elements beyond the real extent (the padding) must be masked out of the softmax
GemmCoord *problem_sizes_real;
int problem_count;
int threadblock_count;
ElementA ** ptr_A;
ElementB ** ptr_B;
ElementC ** ptr_C;
ElementC ** ptr_D;
ElementNorm **ptr_Max;
ElementSum **ptr_Sum;
typename LayoutA::Stride::LongIndex *lda;
typename LayoutB::Stride::LongIndex *ldb;
typename LayoutC::Stride::LongIndex *ldc;
typename LayoutC::Stride::LongIndex *ldd;
typename EpilogueVisitor::Arguments epilogue_visitor;
// Only used by device-level operator
GemmCoord *host_problem_sizes;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments():
problem_count(0),
threadblock_count(0),
ptr_A(nullptr),
ptr_B(nullptr),
ptr_C(nullptr),
ptr_D(nullptr),
ptr_Max(nullptr),
ptr_Sum(nullptr),
lda(nullptr),
ldb(nullptr),
ldc(nullptr),
ldd(nullptr),
host_problem_sizes(nullptr)
{
}
/// Ctor
CUTLASS_HOST_DEVICE
Arguments(
GemmCoord *problem_sizes,
int problem_count,
int threadblock_count,
ElementA ** ptr_A,
ElementB ** ptr_B,
ElementC ** ptr_C,
ElementC ** ptr_D,
ElementNorm **ptr_Max,
ElementSum **ptr_Sum,
typename LayoutA::Stride::LongIndex *lda,
typename LayoutB::Stride::LongIndex *ldb,
typename LayoutC::Stride::LongIndex *ldc,
typename LayoutC::Stride::LongIndex *ldd,
typename EpilogueVisitor::Arguments epilogue_visitor_,
GemmCoord *host_problem_sizes=nullptr,
GemmCoord *problem_sizes_real=nullptr
):
problem_sizes(problem_sizes),
problem_count(problem_count),
threadblock_count(threadblock_count),
ptr_A(ptr_A),
ptr_B(ptr_B),
ptr_C(ptr_C),
ptr_D(ptr_D),
ptr_Max(ptr_Max),
ptr_Sum(ptr_Sum),
lda(lda),
ldb(ldb),
ldc(ldc),
ldd(ldd),
epilogue_visitor(epilogue_visitor_),
host_problem_sizes(host_problem_sizes),
problem_sizes_real(problem_sizes_real)
{
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params {
typename ProblemVisitor::Params problem_visitor;
GemmCoord *problem_sizes_real;
int threadblock_count;
ElementA ** ptr_A;
ElementB ** ptr_B;
ElementC ** ptr_C;
ElementC ** ptr_D;
ElementNorm **ptr_Max;
ElementSum **ptr_Sum;
typename LayoutA::Stride::LongIndex *lda;
typename LayoutB::Stride::LongIndex *ldb;
typename LayoutC::Stride::LongIndex *ldc;
typename LayoutC::Stride::LongIndex *ldd;
typename EpilogueVisitor::Params epilogue_visitor;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
ptr_A(nullptr),
ptr_B(nullptr),
ptr_C(nullptr),
ptr_D(nullptr),
ptr_Max(nullptr),
ptr_Sum(nullptr),
lda(nullptr),
ldb(nullptr),
ldc(nullptr),
ldd(nullptr),
problem_sizes_real(nullptr)
{ }
CUTLASS_HOST_DEVICE
Params(Arguments const &args, void *workspace = nullptr, int32_t tile_count = 0):
problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count),
threadblock_count(args.threadblock_count),
ptr_A(args.ptr_A),
ptr_B(args.ptr_B),
ptr_C(args.ptr_C),
ptr_D(args.ptr_D),
ptr_Max(args.ptr_Max),
ptr_Sum(args.ptr_Sum),
lda(args.lda),
ldb(args.ldb),
ldc(args.ldc),
ldd(args.ldd),
epilogue_visitor(args.epilogue_visitor),
problem_sizes_real(args.problem_sizes_real)
{
}
CUTLASS_HOST_DEVICE
void update(
Arguments const &args,
void *workspace = nullptr,
int32_t tile_count = -1) {
problem_visitor = typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count);
threadblock_count = args.threadblock_count;
ptr_A = args.ptr_A;
ptr_B = args.ptr_B;
ptr_C = args.ptr_C;
ptr_D = args.ptr_D;
ptr_Max = args.ptr_Max;
ptr_Sum = args.ptr_Sum;
lda = args.lda;
ldb = args.ldb;
ldc = args.ldc;
ldd = args.ldd;
problem_sizes_real = args.problem_sizes_real;
}
};
/// Shared memory storage structure
struct SharedStorage {
union {
typename Mma::SharedStorage main_loop;
struct {
typename Epilogue::SharedStorage epilogue;
typename EpilogueVisitor::SharedStorage visitor;
} epilogue;
} kernel;
// ProblemVisitor shared storage can't be overlapped with others
typename ProblemVisitor::SharedStorage problem_visitor;
};
public:
//
// Methods
//
CUTLASS_DEVICE
GemmGroupedWithEpilogueVistor() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) {
return Status::kSuccess;
}
static Status can_implement(Arguments const &args) {
return Status::kSuccess;
}
static size_t get_extra_workspace_size(
Arguments const &args,
cutlass::gemm::GemmCoord const &grid_tiled_shape) {
return 0;
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
//
// These types shadow the type-level definitions and support the ability to implement
// a 'transposed' GEMM that computes the transposed problems.
//
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename EpilogueVisitor::ElementOutput;
using LayoutC = typename Mma::LayoutC;
//
// Problem visitor.
//
ProblemVisitor problem_visitor(
params.problem_visitor,
shared_storage.problem_visitor,
blockIdx.x);
// Outer 'persistent' loop to iterate over tiles
while (problem_visitor.next_tile()) {
GemmCoord problem_size = problem_visitor.problem_size();
int32_t problem_idx = problem_visitor.problem_index();
int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
cutlass::gemm::GemmCoord threadblock_offset(
int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM,
int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN,
0);
// Load element pointers. Exchange pointers and strides if working on the transpose
ElementA *ptr_A = reinterpret_cast<ElementA *>((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx]));
typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]);
ElementB *ptr_B = reinterpret_cast<ElementB *>((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx]));
typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]);
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_offset.m(),
0,
};
cutlass::MatrixCoord tb_offset_B{
0,
threadblock_offset.n()
};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
LayoutA(ldm_A),
ptr_A,
{problem_size.m(), problem_size.k()},
thread_idx,
tb_offset_A);
typename Mma::IteratorB iterator_B(
LayoutB(ldm_B),
ptr_B,
{problem_size.k(), problem_size.n()},
thread_idx,
tb_offset_B);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Matrix multiply phase
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx);
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Wait for all threads to finish their epilogue phases from the previous tile.
__syncthreads();
// Compute threadblock-scoped matrix multiply-add
mma(
gemm_k_iterations,
accumulators,
iterator_A,
iterator_B,
accumulators);
ElementC *ptr_C = params.ptr_C[problem_idx];
ElementC *ptr_D = params.ptr_D[problem_idx];
ElementNorm *ptr_Max = params.ptr_Max[problem_idx];
ElementSum *ptr_Sum = params.ptr_Sum[problem_idx];
LayoutC layout_C(params.ldc[problem_idx]);
LayoutC layout_D(params.ldd[problem_idx]);
int column_offset = (threadblock_offset.n() / ThreadblockShape::kN) * problem_size.m();
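// Partial row statistics appear to be stored per threadblock column: column_offset selects
// this column tile's slice of length problem_size.m() within the Max/Sum workspaces, which
// the final reduction kernel later combines (our reading of the indexing).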
typename EpilogueVisitor::OutputTileIterator::Params params_C(layout_C);
typename EpilogueVisitor::OutputTileIterator::Params params_D(layout_D);
//
// Construct the epilogue visitor
//
EpilogueVisitor epilogue_visitor(
params.epilogue_visitor,
shared_storage.kernel.epilogue.visitor,
problem_size.mn(),
thread_idx,
warp_idx,
lane_idx,
params_C,
params_D,
ptr_C,
ptr_D,
ptr_Max,
ptr_Sum,
threadblock_offset.mn(),
column_offset,
params.problem_sizes_real[problem_idx].mn()
);
// Construct the epilogue
Epilogue epilogue(
shared_storage.kernel.epilogue.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Execute the epilogue operator to update the destination tensor
epilogue(epilogue_visitor, accumulators);
// Next tile
problem_visitor.advance(gridDim.x);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -116,6 +116,10 @@ foreach(EXAMPLE
34_transposed_conv2d
35_gemm_softmax
36_gather_scatter_fusion
37_gemm_layernorm_gemm_fusion
38_syr2k_grouped
39_gemm_permute
41_multi_head_attention
)
add_subdirectory(${EXAMPLE})

View File

@ -98,6 +98,9 @@ template <
bool IsHermitianData = false>
struct cp_async_diag;
static const uint32_t OOB_NAN_F16 = 0x7eff;
static const uint32_t OOB_NAN_F16x2 = ((OOB_NAN_F16 << 16) | OOB_NAN_F16);
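// Note: 0x7eff has an all-ones exponent and a nonzero mantissa in IEEE binary16, i.e. it
// is a NaN; OOB_NAN_F16x2 packs two such values per 32-bit word so that predicated-off
// (out-of-bounds) shared-memory destinations are filled with NaNs.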
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization
@ -190,8 +193,8 @@ struct cp_async_nan<16, CacheOperation::Always> {
cp_async_nan(void *smem_ptr, void const *global_ptr, bool pred_guard) {
#if CUDA_CP_ASYNC_ACTIVATED
static __constant__ uint4 OOB_NAN_F16x8 = {0x7eff7eff, 0x7eff7eff,
0x7eff7eff, 0x7eff7eff};
static __constant__ uint4 OOB_NAN_F16x8 = {OOB_NAN_F16x2, OOB_NAN_F16x2,
OOB_NAN_F16x2, OOB_NAN_F16x2};
unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr);
@ -305,7 +308,6 @@ struct cp_async_diag <Element_, true> {
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization
@ -386,6 +388,47 @@ struct cp_async_zfill<SizeInBytes, CacheOperation::Global> {
}
};
/// Partial specialization
template <>
struct cp_async_nan<16, CacheOperation::Global> {
static int const kSizeInBytes = 16;
/// Copy with nan fill
CUTLASS_DEVICE
cp_async_nan(void *smem_ptr, void const *global_ptr, bool pred_guard) {
#if CUDA_CP_ASYNC_ACTIVATED
static __constant__ uint4 OOB_NAN_F16x8 = {OOB_NAN_F16x2, OOB_NAN_F16x2,
OOB_NAN_F16x2, OOB_NAN_F16x2};
unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
#if CUTLASS_ENABLE_L2_PREFETCH
" @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n"
#else
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
#endif
" @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};\n"
"}\n"
:
: "r"((int)pred_guard), "r"(smem_int_ptr), "l"(global_ptr),
"n"(kSizeInBytes), "r"(OOB_NAN_F16x8.x), "r"(OOB_NAN_F16x8.y), "r"(OOB_NAN_F16x8.z),
"r"(OOB_NAN_F16x8.w));
#else
CUTLASS_UNUSED(smem_ptr);
CUTLASS_UNUSED(global_ptr);
CUTLASS_UNUSED(pred_guard);
CUTLASS_NOT_IMPLEMENTED();
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block.

View File

@ -126,6 +126,10 @@ struct Mma<
: "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(C[0]), "r"(C[1]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -188,6 +192,10 @@ struct Mma<
);
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -251,6 +259,10 @@ struct Mma<
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -308,6 +320,10 @@ struct Mma<
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -366,6 +382,10 @@ struct Mma<
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -423,6 +443,10 @@ struct Mma<
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -486,6 +510,10 @@ struct Mma<
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -543,6 +571,10 @@ struct Mma<
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -600,6 +632,10 @@ struct Mma<
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -657,6 +693,10 @@ struct Mma<
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -711,7 +751,6 @@ struct Mma<
unsigned const & A = reinterpret_cast<unsigned const &>(a);
unsigned const & B = reinterpret_cast<unsigned const &>(b);
int const *C = reinterpret_cast<int const *>(&c);
int *D = reinterpret_cast<int *>(&d);
@ -720,6 +759,10 @@ struct Mma<
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -777,6 +820,10 @@ struct Mma<
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -834,6 +881,10 @@ struct Mma<
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -891,6 +942,10 @@ struct Mma<
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -954,6 +1009,10 @@ struct Mma<
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -1011,6 +1070,10 @@ struct Mma<
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -1068,6 +1131,10 @@ struct Mma<
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -1125,6 +1192,10 @@ struct Mma<
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -1210,11 +1281,19 @@ struct Mma<
nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0); // WMMA must be supported to issue binary matrix multiply-accumulate instructions.
#endif // defined(CUTLASS_ARCH_WMMA_ENABLED)
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif

View File

@ -587,6 +587,10 @@ struct Mma<
"r"(C[3]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -1571,6 +1575,10 @@ struct Mma<
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -1631,6 +1639,10 @@ struct Mma<
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -1691,6 +1703,10 @@ struct Mma<
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -1751,6 +1767,10 @@ struct Mma<
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -1818,6 +1838,10 @@ struct Mma<
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -1878,6 +1902,10 @@ struct Mma<
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -1938,6 +1966,10 @@ struct Mma<
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -1998,6 +2030,10 @@ struct Mma<
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -2059,6 +2095,10 @@ struct Mma<
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -2126,6 +2166,10 @@ struct Mma<
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif // defined(CUTLASS_ARCH_MMA_SM80_ENABLED)

View File

@ -141,6 +141,10 @@ struct SparseMma<
assert(0);
}
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -224,6 +228,10 @@ struct SparseMma<
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -296,6 +304,10 @@ struct SparseMma<gemm::GemmShape<16, 8, 32>, 32, bfloat16_t, layout::RowMajor,
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -368,6 +380,10 @@ struct SparseMma<gemm::GemmShape<16, 8, 16>, 32, tfloat32_t, layout::RowMajor,
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -449,6 +465,10 @@ struct SparseMma<
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -524,6 +544,10 @@ struct SparseMma<
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -599,6 +623,10 @@ struct SparseMma<
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -674,6 +702,10 @@ struct SparseMma<
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -755,6 +787,10 @@ struct SparseMma<
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -830,6 +866,10 @@ struct SparseMma<
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -905,6 +945,10 @@ struct SparseMma<
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -980,6 +1024,10 @@ struct SparseMma<
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -1061,6 +1109,10 @@ struct SparseMma<
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -1136,6 +1188,10 @@ struct SparseMma<
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -1211,6 +1267,10 @@ struct SparseMma<
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -1286,6 +1346,10 @@ struct SparseMma<
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -1367,6 +1431,10 @@ struct SparseMma<
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -1442,6 +1510,10 @@ struct SparseMma<
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -1517,6 +1589,10 @@ struct SparseMma<
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
@ -1592,6 +1668,10 @@ struct SparseMma<
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}

View File

@ -127,8 +127,8 @@ template <typename T>
class complex
{
public:
/// Type alias for scalar type
using value_type = T;
/// Type alias for scalar type
using value_type = T;
private:
//

View File

@ -268,7 +268,7 @@ public:
CUTLASS_HOST_DEVICE
int64_t filter_size() const {
return (K * R * S * C);
return (K * R * S * C / groups);
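// In grouped convolution each of the K filters spans only C / groups input channels, so,
// e.g., K=64, R=S=3, C=64, groups=2 yields 64*3*3*32 filter elements rather than 64*3*3*64.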
}
/// Returns output size in number of elements
@ -362,61 +362,128 @@ int implicit_gemm_k_iterations(
Operator conv_operator,
int threadblock_K,
Conv2dProblemSize const &problem_size,
IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic) {
IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic,
GroupMode group_mode = GroupMode::kNone,
int threadblock_N = 0) {
int iterations = 0;
if (algorithm == IteratorAlgorithm::kFixedChannels) {
if (group_mode == GroupMode::kNone) {
int positions_per_iteration = threadblock_K / problem_size.C;
switch (conv_operator) {
case Operator::kFprop:
iterations = (problem_size.R * problem_size.S + positions_per_iteration - 1 ) / positions_per_iteration;
break;
if (algorithm == IteratorAlgorithm::kFixedChannels) {
default:
break;
int positions_per_iteration = threadblock_K / problem_size.C;
switch (conv_operator) {
case Operator::kFprop:
iterations = (problem_size.R * problem_size.S + positions_per_iteration - 1 ) / positions_per_iteration;
break;
default:
break;
}
}
}
else if (algorithm == IteratorAlgorithm::kFewChannels) {
else if (algorithm == IteratorAlgorithm::kFewChannels) {
switch (conv_operator) {
case Operator::kFprop:
iterations = (problem_size.R * problem_size.S * problem_size.C + threadblock_K - 1 ) / threadblock_K;
break;
switch (conv_operator) {
case Operator::kFprop:
iterations = (problem_size.R * problem_size.S * problem_size.C + threadblock_K - 1 ) / threadblock_K;
break;
default:
break;
default:
break;
}
}
}
else {
int elements_per_split_k_slice = 0;
else {
int elements_per_split_k_slice = 0;
switch (conv_operator) {
case Operator::kFprop:
elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
break;
switch (conv_operator) {
case Operator::kFprop:
elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
break;
case Operator::kDgrad:
elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
break;
case Operator::kDgrad:
elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
break;
case Operator::kWgrad:
elements_per_split_k_slice = (problem_size.N * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K;
break;
case Operator::kWgrad:
elements_per_split_k_slice = (problem_size.N * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K;
break;
default:
break;
default:
break;
}
}
} else if (group_mode == GroupMode::kDepthwise) {
int channels_per_cta = threadblock_N;
if (algorithm == IteratorAlgorithm::kAnalytic) {
switch (conv_operator) {
case Operator::kFprop:
iterations = problem_size.R * problem_size.S *
((channels_per_cta + threadblock_K - 1) / threadblock_K);
break;
default:
break;
}
}
} else { // Group conv
int channels_per_group = problem_size.C / problem_size.groups;
int k_per_group = problem_size.K / problem_size.groups;
if (algorithm == IteratorAlgorithm::kAnalytic) {
switch (conv_operator) {
case Operator::kFprop:
iterations = problem_size.R * problem_size.S * ((channels_per_group + threadblock_K - 1) / threadblock_K);
// In group conv, if k_per_group < threadblock_N, one Threadblock will calculate multiple groups
if (problem_size.groups != 1) {
if (k_per_group < threadblock_N) {
iterations *= threadblock_N / k_per_group;
}
}
break;
default:
break;
}
}
}
return iterations;
}
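For orientation, the grouped-Fprop branch above is plain integer arithmetic. The following standalone sketch (all sizes hypothetical, not taken from this commit) reproduces the kMultipleGroup case, where one CTA's N tile spans several groups and the main-loop iteration count is scaled accordingly.
#include <cassert>
int main() {
  // Hypothetical grouped Fprop problem: 3x3 filter, C = K = 128, groups = 4.
  int R = 3, S = 3, C = 128, K = 128, groups = 4;
  int threadblock_K = 32;   // CTA tile extent along the GEMM K dimension
  int threadblock_N = 64;   // CTA tile extent along the GEMM N dimension
  int channels_per_group = C / groups;   // 32
  int k_per_group        = K / groups;   // 32
  // Base count: R * S * ceil(channels_per_group / threadblock_K) = 3 * 3 * 1 = 9
  int iterations = R * S * ((channels_per_group + threadblock_K - 1) / threadblock_K);
  // k_per_group (32) < threadblock_N (64): each CTA covers 2 groups, so it walks the
  // per-group K extent twice -> 18 iterations.
  if (groups != 1 && k_per_group < threadblock_N) {
    iterations *= threadblock_N / k_per_group;
  }
  assert(iterations == 18);
  return 0;
}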
CUTLASS_HOST_DEVICE
int implicit_gemm_k_iterations_per_channel(
Operator conv_operator,
int threadblock_K,
Conv2dProblemSize const &problem_size,
IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic) {
int iterations = 0;  // 0 means not applicable
if (algorithm == IteratorAlgorithm::kAnalytic || algorithm == IteratorAlgorithm::kOptimized) {
switch (conv_operator) {
case Operator::kFprop:
iterations = problem_size.R * problem_size.S;
break;
case Operator::kDgrad:
iterations = problem_size.R * problem_size.S;
break;
default:
break;
}
}
return iterations;
}
////////////////////////////////////////////////////////////////////////////////
// Mapping function (ImplicitGemm A, B, C -> Conv Activation, Filter, Output)
////////////////////////////////////////////////////////////////////////////////
@ -537,12 +604,12 @@ void strided_dgrad_starting_coords(
// function locals for remainder by fast divmod
int pad_h_rem_, pad_w_rem_;
// start_h = platform::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h;
// start_h = std::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h;
stride_h_divmod.divmod(pad_h_rem_, problem_size.pad_h);
int r_ = absolute_value(problem_size.stride_h - (pad_h_rem_ - r));
stride_h_divmod.divmod(start_h, r_);
//start_w = platform::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w;
//start_w = std::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w;
stride_w_divmod.divmod(pad_w_rem_, problem_size.pad_w);
int s_ = absolute_value(problem_size.stride_w - (pad_w_rem_ - s));
stride_w_divmod.divmod(start_w, s_);

View File

@ -339,29 +339,46 @@ int implicit_gemm_k_iterations(
Operator conv_operator,
int threadblock_K,
Conv3dProblemSize const &problem_size,
IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic) {
IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic,
GroupMode group_mode = GroupMode::kNone,
int threadblock_N = 0) {
int iterations = 0;
int elements_per_split_k_slice = 0;
if (group_mode == GroupMode::kNone) {
switch (conv_operator) {
case Operator::kFprop:
elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
break;
case Operator::kDgrad:
elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
break;
case Operator::kWgrad:
elements_per_split_k_slice = (problem_size.N * problem_size.Z * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K;
break;
default:
break;
}
} else if (group_mode == GroupMode::kDepthwise) {
int channels_per_cta = threadblock_N;
switch (conv_operator) {
case Operator::kFprop:
elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
break;
case Operator::kDgrad:
elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K);
break;
case Operator::kWgrad:
elements_per_split_k_slice = (problem_size.N * problem_size.Z * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K;
break;
default:
break;
if (algorithm == IteratorAlgorithm::kAnalytic) {
switch (conv_operator) {
case Operator::kFprop:
iterations = problem_size.T * problem_size.R * problem_size.S *
((channels_per_cta + threadblock_K - 1) / threadblock_K);
break;
default:
break;
}
}
}
return iterations;

View File

@ -117,6 +117,14 @@ enum class SplitKMode {
kParallel
};
/// Identifies group mode
enum class GroupMode {
kNone,
kSingleGroup, ///< One CTA calculates one group or less
kMultipleGroup, ///< One CTA calculates multiple groups
kDepthwise ///< One CTA calculates cta_n groups (problem_size.C == problem_size.K == problem_size.groups)
};
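To make the intent of these modes concrete, the helper below is a purely illustrative sketch (hypothetical, not part of this commit) of how a grouped Fprop problem could be classified from its channel counts and the CTA's N extent.
// Illustrative only: classify a grouped Fprop problem into a GroupMode.
// Assumes C and K are the input/output channel counts and threadblock_N is ThreadblockShape::kN.
inline GroupMode select_group_mode(int C, int K, int groups, int threadblock_N) {
  if (groups == 1) {
    return GroupMode::kNone;
  }
  if (C == groups && K == groups) {
    return GroupMode::kDepthwise;                 // one channel per group
  }
  int k_per_group = K / groups;
  // A CTA stays within one group when each group owns at least a full N tile of output channels.
  return (k_per_group >= threadblock_N) ? GroupMode::kSingleGroup
                                        : GroupMode::kMultipleGroup;
}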
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace conv

View File

@ -78,6 +78,7 @@ public:
static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemmKernel::kConvolutionalOperator;
static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = ImplicitGemmKernel::kIteratorAlgorithm;
static cutlass::conv::StrideSupport const kStrideSupport = ImplicitGemmKernel::kStrideSupport;
static cutlass::conv::GroupMode const kGroupMode = ImplicitGemmKernel::kGroupMode;
static int const kWarpCount =
(ThreadblockShape::kM / WarpShape::kM) *
@ -111,6 +112,34 @@ public:
return status;
}
// check group conv constraint
if (args.problem_size.groups != 1) {
if (kGroupMode == conv::GroupMode::kNone) {
return Status::kErrorInvalidProblem;
}
// C and K should be multiples of groups
if (args.problem_size.K % args.problem_size.groups ||
args.problem_size.C % args.problem_size.groups) {
return Status::kErrorInvalidProblem;
}
// split-k is not supported
if (args.problem_size.split_k_slices != 1) {
return Status::kErrorInvalidProblem;
}
int k_per_group = args.problem_size.K / args.problem_size.groups;
// k_per_group should be a multiple of ThreadblockShape::kN; one CTA calculates one group
if (kGroupMode == conv::GroupMode::kSingleGroup && k_per_group % ThreadblockShape::kN) {
return Status::kErrorInvalidProblem;
}
// ThreadblockShape::kN should be divisible by k_per_group; one CTA calculates multiple groups
if (kGroupMode == conv::GroupMode::kMultipleGroup && ThreadblockShape::kN % k_per_group) {
return Status::kErrorInvalidProblem;
}
}
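// Editorial illustration (hypothetical numbers, not part of this commit): with
// ThreadblockShape::kN == 128 the divisibility checks above behave as follows.
//   kSingleGroup:   k_per_group = 256 -> 256 % 128 == 0 -> accepted (CTA stays inside one group)
//                   k_per_group =  96 ->  96 % 128 != 0 -> Status::kErrorInvalidProblem
//   kMultipleGroup: k_per_group =  32 -> 128 %  32 == 0 -> accepted (CTA spans 4 groups)
//                   k_per_group =  48 -> 128 %  48 != 0 -> Status::kErrorInvalidProblem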
static int const kAlignmentC = ImplicitGemmKernel::Epilogue::OutputTileIterator::kElementsPerAccess;
if (kConvolutionalOperator == conv::Operator::kFprop) {
if (args.problem_size.K % kAlignmentC)

View File

@ -45,8 +45,8 @@
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
#include "cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h"
#include "cutlass/conv/threadblock/regular_scale_bias_vector_access_iterator.h"
#include "cutlass/conv/warp/conv2d_fprop_scale_bias_iterator.h"
#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h"
#include "cutlass/gemm/warp/scale_bias_tile_iterator.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -161,7 +161,7 @@ struct DefaultConv2dFpropFusion <
LayoutScaleBias>;
using SmemIteratorScaleBias =
cutlass::conv::threadblock::RegularScaleBiasVectorAccessIterator<
cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator<
cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
LayoutScaleBias>;
@ -172,7 +172,7 @@ struct DefaultConv2dFpropFusion <
static int const kThreadCount = 32;
// Warp-level iterators to load scale and bias vectors
using WarpIteratorScaleBias = cutlass::conv::warp::WarpIteratorScaleBias<
using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator<
MatrixShape<WarpShape::kM, WarpShape::kK>, ElementScaleBias,
LayoutScaleBias, MatrixShape<InstructionShape::kM, InstructionShape::kK>,
typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount,
@ -296,7 +296,7 @@ struct DefaultConv2dFpropFusion <
LayoutScaleBias>;
using SmemIteratorScaleBias =
cutlass::conv::threadblock::RegularScaleBiasVectorAccessIterator<
cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator<
cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
LayoutScaleBias>;
@ -307,7 +307,7 @@ struct DefaultConv2dFpropFusion <
static int const kThreadCount = 32;
// Warp-level iterators to load scale and bias vectors
using WarpIteratorScaleBias = cutlass::conv::warp::WarpIteratorScaleBias<
using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator<
MatrixShape<WarpShape::kM, WarpShape::kK>, ElementScaleBias,
LayoutScaleBias, MatrixShape<InstructionShape::kM, InstructionShape::kK>,
typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount,

View File

@ -0,0 +1,222 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
matrix multiply-add with the appropriate threadblock-scoped epilogue.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_conv2d.h"
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h"
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h"
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h"
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace conv {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv2dGroupFprop
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename OperatorClass,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
conv::GroupMode GroupMode,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
/// Access granularity of A matrix in units of elements
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
/// Access granularity of B matrix in units of elements
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
> struct DefaultConv2dGroupFprop;
/////////////////////////////////////////////////////////////////////////////////////////////////
// OpClassTensorOp convolutions
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv2dGroupFprop specialization for Analytic IteratorAlgorithm and multistage
/// pipeline.
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
conv::GroupMode GroupMode,
conv::StrideSupport StrideSupport,
int AlignmentA,
int AlignmentB
>
struct DefaultConv2dGroupFprop <
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
MathOperatorTag,
GroupMode,
IteratorAlgorithm::kAnalytic,
StrideSupport,
AlignmentA,
AlignmentB
> {
// Define the core components from GEMM
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
Stages, MathOperatorTag>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
using IteratorA =
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA, LayoutA,
ThreadMapA,
AccessTypeA,
GroupMode
>;
using SmemIteratorA = typename MmaCore::SmemIteratorA;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
using IteratorB =
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB, LayoutB,
ThreadMapB,
AccessTypeB,
GroupMode
>;
using SmemIteratorB = typename MmaCore::SmemIteratorB;
// Warp-level GEMM components
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
using MmaPolicy = typename MmaCore::MmaPolicy;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the Mma
using Mma = threadblock::ImplicitGemmMultistage<
ThreadblockShape,
IteratorA,
SmemIteratorA,
arch::CacheOperation::Always,
IteratorB,
SmemIteratorB,
CacheOpB,
MmaPolicy,
Stages
>;
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
// Define the epilogue
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape,
WarpMmaTensorOp,
kPartitionsK,
EpilogueOutputOp,
EpilogueOutputOp::kCount
>::Epilogue;
// Define the kernel
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
Mma,
Epilogue,
ThreadblockSwizzle,
conv::Operator::kFprop,
Conv2dProblemSize,
GroupMode
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace conv
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,360 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
Default kernel-level fused activation scale+bias+relu and implicit GEMM convolution
definitions that combine threadblock-scoped matrix multiply-add with the
appropriate threadblock-scoped epilogue.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_conv2d.h"
#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h"
#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h"
#include "cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h"
#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h"
#include "cutlass/gemm/warp/scale_bias_tile_iterator.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace conv {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for fused batch norm and Conv3dFprop
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementScaleBias,
typename LayoutScaleBias,
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename OperatorClass,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized,
conv::StrideSupport StrideSupport = StrideSupport::kStrided
> struct DefaultConv3dFpropFusion;
/////////////////////////////////////////////////////////////////////////////////////////////////
// OpClassTensorOp convolutions
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv3dFprop specialization for Analytic IteratorAlgorithm and multistage
/// pipeline.
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementScaleBias,
typename LayoutScaleBias,
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag
>
struct DefaultConv3dFpropFusion <
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementScaleBias,
LayoutScaleBias,
ElementC,
LayoutC,
ElementAccumulator,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
MathOperatorTag,
IteratorAlgorithm::kAnalytic
> {
// Define the core components from GEMM
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
Stages, MathOperatorTag>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using IteratorA =
cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
ThreadMapA
>;
using SmemIteratorA = typename MmaCore::SmemIteratorA;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using IteratorB =
cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
ThreadMapB
>;
using SmemIteratorB = typename MmaCore::SmemIteratorB;
/// Define iterators over tiles from scale/bias vectors
using IteratorScaleBias =
cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator<
cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
LayoutScaleBias>;
using SmemIteratorScaleBias =
cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator<
cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
LayoutScaleBias>;
// Warp-level GEMM components
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
using MmaPolicy = typename MmaCore::MmaPolicy;
static int const kThreadCount = 32;
// Warp-level iterators to load scale and bias vectors
using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator<
MatrixShape<WarpShape::kM, WarpShape::kK>, ElementScaleBias,
LayoutScaleBias, MatrixShape<InstructionShape::kM, InstructionShape::kK>,
typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount,
MmaCore::WarpCount::kK>;
// Define the Mma
using Mma = threadblock::ImplicitGemmFpropFusionMultistage<
ThreadblockShape,
IteratorA,
SmemIteratorA,
arch::CacheOperation::Always,
IteratorB,
SmemIteratorB,
arch::CacheOperation::Global,
IteratorScaleBias,
SmemIteratorScaleBias,
arch::CacheOperation::Always,
MmaPolicy,
WarpIteratorScaleBias,
Stages
>;
// Define the epilogue
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape,
WarpMmaTensorOp,
1,
EpilogueOutputOp,
EpilogueOutputOp::kCount
>::Epilogue;
// Define the kernel
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion<
Mma,
Epilogue,
ThreadblockSwizzle,
conv::Operator::kFprop,
Conv3dProblemSize
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv3dFprop specialization for Optimized IteratorAlgorithm and
/// multistage pipeline.
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementScaleBias,
typename LayoutScaleBias,
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag
>
struct DefaultConv3dFpropFusion <
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementScaleBias,
LayoutScaleBias,
ElementC,
LayoutC,
ElementAccumulator,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
MathOperatorTag,
IteratorAlgorithm::kOptimized
> {
// Define the core components from GEMM
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
Stages, MathOperatorTag
>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using IteratorA =
cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
LayoutA,
ThreadMapA
>;
using SmemIteratorA = typename MmaCore::SmemIteratorA;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using IteratorB =
cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
LayoutB,
ThreadMapB
>;
using SmemIteratorB = typename MmaCore::SmemIteratorB;
/// Define iterators over tiles from scale/bias vectors
using IteratorScaleBias =
cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator<
cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
LayoutScaleBias>;
using SmemIteratorScaleBias =
cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator<
cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias,
LayoutScaleBias>;
// Warp-level GEMM components
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
using MmaPolicy = typename MmaCore::MmaPolicy;
static int const kThreadCount = 32;
// Warp-level iterators to load scale and bias vectors
using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator<
MatrixShape<WarpShape::kM, WarpShape::kK>, ElementScaleBias,
LayoutScaleBias, MatrixShape<InstructionShape::kM, InstructionShape::kK>,
typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount,
MmaCore::WarpCount::kK>;
// Define the Mma
using Mma = threadblock::ImplicitGemmFpropFusionMultistage<
ThreadblockShape,
IteratorA,
SmemIteratorA,
arch::CacheOperation::Always,
IteratorB,
SmemIteratorB,
arch::CacheOperation::Global,
IteratorScaleBias,
SmemIteratorScaleBias,
arch::CacheOperation::Always,
MmaPolicy,
WarpIteratorScaleBias,
Stages
>;
// Define the epilogue
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape,
WarpMmaTensorOp,
1,
EpilogueOutputOp,
EpilogueOutputOp::kCount
>::Epilogue;
// Define the kernel
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion<
Mma,
Epilogue,
ThreadblockSwizzle,
conv::Operator::kFprop,
Conv3dProblemSize
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace conv
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,218 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
Default kernel-level Depthwise implicit GEMM convolution definitions combine threadblock-scoped
matrix multiply-add with the appropriate threadblock-scoped epilogue.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_conv2d.h"
#include "cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h"
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/depthwise_fprop_pipelined.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace conv {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv2dFprop
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename OperatorClass,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
/// Access granularity of A matrix in units of elements
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
/// Access granularity of B matrix in units of elements
int AlignmentB = cutlass::sizeof_bits<ElementB>::value / cutlass::sizeof_bits<ElementB>::value
> struct DefaultDepthwiseFprop;
/////////////////////////////////////////////////////////////////////////////////////////////////
// OpClassSimt convolutions
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Depthwise specialization for Analytic IteratorAlgorithm,
/// 2 stage pipeline, and FFMA-based mainloop for SM50
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag,
conv::StrideSupport StrideSupport,
int AlignmentA,
int AlignmentB
>
struct DefaultDepthwiseFprop <
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
arch::OpClassSimt,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
2,
MathOperatorTag, // cutlass::arch::OpMultiplyAdd
IteratorAlgorithm::kAnalytic,
StrideSupport,
AlignmentA,
AlignmentB
> {
// Define the core components from GEMM
using MmaCore = typename cutlass::conv::threadblock::DepthwiseMmaCoreWithLaneAccessSize<
ThreadblockShape,
WarpShape,
InstructionShape,
ElementA,
layout::RowMajor,
ElementB,
layout::ColumnMajor,
ElementAccumulator,
layout::RowMajor,
arch::OpClassSimt,
128,
sizeof_bits<ElementB>::value,
2,
MathOperatorTag>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using IteratorA =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA, LayoutA,
ThreadMapA
>
>;
using SmemIteratorA = typename MmaCore::SmemIteratorA;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
using IteratorB =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB, LayoutB,
ThreadMapB,
AccessTypeB,
cutlass::conv::GroupMode::kDepthwise
>
>;
using SmemIteratorB = typename MmaCore::SmemIteratorB;
// Warp-level GEMM components
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
using MmaPolicy = typename MmaCore::MmaPolicy;
// Define the Mma
using Mma = threadblock::DepthwiseFpropPipelined<
ThreadblockShape,
IteratorA,
SmemIteratorA,
IteratorB,
SmemIteratorB,
ElementC,
LayoutC,
MmaPolicy
>;
// Define the epilogue
using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt<
ThreadblockShape,
WarpMmaSimtOp,
EpilogueOutputOp,
EpilogueOutputOp::kCount
>::Epilogue;
// Define the kernel
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
Mma,
Epilogue,
ThreadblockSwizzle,
conv::Operator::kFprop,
Conv2dProblemSize,
cutlass::conv::GroupMode::kDepthwise
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace conv
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -62,7 +62,8 @@ template <
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad)
typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem
typename ConvProblemSize_ = Conv2dProblemSize, ///! Convolutional operator on 2D or 3D problem
conv::GroupMode GroupMode_ = conv::GroupMode::kNone ///! Group mode
>
struct ImplicitGemmConvolution {
@ -117,6 +118,8 @@ struct ImplicitGemmConvolution {
/// Conv dimension and problem size structure (Conv2d or Conv3d)
using ConvProblemSize = ConvProblemSize_;
static conv::GroupMode const kGroupMode = GroupMode_;
/// Wgrad C stride idx for implicit gemm algorithm
// Conv2d row-major matrix C (KxRSC)
// Conv3d row-major matrix C (KxTRSC)
@ -198,6 +201,7 @@ struct ImplicitGemmConvolution {
int swizzle_log_tile;
int gemm_k_iterations;
int gemm_k_iterations_per_channel;
typename Mma::IteratorA::Params iterator_A;
typename Mma::IteratorA::Element const *ptr_A;
typename Mma::IteratorB::Params iterator_B;
@ -241,7 +245,12 @@ struct ImplicitGemmConvolution {
kConvolutionalOperator,
ThreadblockShape::kK,
args.problem_size,
kIteratorAlgorithm);
kIteratorAlgorithm,
kGroupMode,
ThreadblockShape::kN);
gemm_k_iterations_per_channel = implicit_gemm_k_iterations_per_channel(
kConvolutionalOperator, ThreadblockShape::kK, args.problem_size, kIteratorAlgorithm);
ThreadblockSwizzle threadblock_swizzle;
@ -286,6 +295,17 @@ struct ImplicitGemmConvolution {
// Compute position within threadblock
int thread_idx = threadIdx.x;
int iterator_A_column_offset = threadblock_tile_idx.k() * Mma::Shape::kK;
if (kGroupMode != GroupMode::kNone) {
if (kGroupMode != GroupMode::kDepthwise) {
int k_per_group = params.problem_size.K / params.problem_size.groups;
int group_idx = threadblock_tile_idx.n() * Mma::Shape::kN / k_per_group;
int channels_per_group = params.problem_size.C / params.problem_size.groups;
iterator_A_column_offset += group_idx * channels_per_group;
} else {
iterator_A_column_offset += threadblock_tile_idx.n() * Mma::Shape::kN;
}
}
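// Editorial illustration (hypothetical sizes): for kSingleGroup with C = 256, K = 512,
// groups = 4 and Mma::Shape::kN = 64, a CTA at threadblock_tile_idx.n() = 5 computes
//   k_per_group        = 512 / 4 = 128
//   group_idx          = 5 * 64 / 128 = 2
//   channels_per_group = 256 / 4 = 64
// so its activation (A) iterator starts at input channel 2 * 64 = 128 (plus the usual
// split-K column offset), i.e. at the channel range owned by the group that produces
// its output-channel tile.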
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
@ -295,7 +315,7 @@ struct ImplicitGemmConvolution {
thread_idx,
MatrixCoord(
threadblock_tile_idx.m() * Mma::Shape::kM,
threadblock_tile_idx.k() * Mma::Shape::kK
iterator_A_column_offset
)
);
@ -327,7 +347,7 @@ struct ImplicitGemmConvolution {
accumulators.clear();
// Compute threadblock-scoped matrix multiply-add
mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, params.gemm_k_iterations_per_channel);
//
// Epilogue

View File

@ -119,6 +119,8 @@ struct ImplicitGemmConvolutionFusion {
/// Conv dimension and problem size structure (Conv2d or Conv3d)
using ConvProblemSize = ConvProblemSize_;
static conv::GroupMode const kGroupMode = conv::GroupMode::kNone;
/// Wgrad C stride idx for implicit gemm algorithm
// Conv2d row-major matrix C (KxRSC)
// Conv3d row-major matrix C (KxTRSC)

View File

@ -117,6 +117,8 @@ struct ImplicitGemmConvolutionStridedDgrad {
/// Conv dimension and problem size structure (Conv2d or Conv3d)
using ConvProblemSize = ConvProblemSize_;
static conv::GroupMode const kGroupMode = conv::GroupMode::kNone;
/// Wgrad C stride idx for implicit gemm algorithm
// Conv2d row-major matrix C (KxRSC)
// Conv3d row-major matrix C (KxTRSC)
@ -488,4 +490,3 @@ struct ImplicitGemmConvolutionStridedDgrad {
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -117,6 +117,8 @@ struct ImplicitGemmConvolutionWithFusedEpilogue {
/// Conv dimension and problem size structure (Conv2d or Conv3d)
using ConvProblemSize = ConvProblemSize_;
static conv::GroupMode const kGroupMode = conv::GroupMode::kNone;
/// Wgrad C stride idx for implicit gemm algorithm
// Conv2d row-major matrix C (KxRSC)
// Conv3d row-major matrix C (KxTRSC)

View File

@ -248,7 +248,7 @@ public:
pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
}
CUTLASS_HOST_DEVICE
CUTLASS_DEVICE
void advance() {
int next_idx = 0;
@ -263,18 +263,33 @@ public:
// Move filter_r by stride_h
filter_r_ += problem_size_.stride_h;
#if 0
bool check = (filter_r_ < problem_size_.R);
filter_r_ = check ? filter_r_ : start_r_;
next_idx = check ? 1 : 2;
reset_bytes += (check ? reset_bytes_s_ : reset_bytes_r_);
#else
asm volatile(
"{\n\t"
" .reg .pred %%p;\n\t"
" .reg .s64 t1;\n\t"
" setp.lt.s32 %%p, %3, %4;\n\t"
" selp.s32 %0, %3, %5, %%p;\n\t"
" selp.s32 %1, 1, 2, %%p;\n\t"
" selp.s64 t1, %6, %7, %%p;\n\t"
" add.s64 %2, %8, t1;\n\t"
"}\n"
: "=r"(filter_r_), "=r"(next_idx), "=l"(reset_bytes)
: "r"(filter_r_), "r"(problem_size_.R), "r"(start_r_),
"l"(reset_bytes_s_), "l"(reset_bytes_r_), "l"(reset_bytes));
#endif
}
// offset pointers by offset_bytes
pointer_ += (params_.inc_next[next_idx] - reset_bytes);
if (next_idx == 2) {
if (next_idx == 2) {
filter_k_ += params_.filter_k_delta;
}

View File

@ -528,7 +528,6 @@ public:
int k = filter_k_ + iteration_vector_ * AccessType::kElements;
return TensorCoord(n, p, q, k);
}
/// Returns true if the current coordinate is within the output tensor Dy

View File

@ -321,7 +321,7 @@ public:
add_byte_offset_(pointer_offset * sizeof_bits<Element>::value / 8);
}
CUTLASS_HOST_DEVICE
CUTLASS_DEVICE
void advance() {
int next_idx = 0;
@ -336,8 +336,9 @@ public:
// Move filter_r by stride_h
filter_r_ += problem_size_.stride_h;
#if 0
if (filter_r_ < problem_size_.R) {
next_idx = 1;
// Restore bytes in q coordinate (Mma in filter s dimension)
@ -347,12 +348,25 @@ public:
// Restore filter_r
filter_r_ = start_r_;
next_idx = 2;
// Restore bytes in p and q coordinate (Mma in filter s and r dimensions)
reset_bytes = reset_bytes_r_;
}
#else
asm volatile(
"{\n\t"
" .reg .pred %%p;\n\t"
" setp.lt.s32 %%p, %3, %4;\n\t"
" selp.s32 %0, %3, %5, %%p;\n\t"
" selp.s32 %1, 1, 2, %%p;\n\t"
" selp.s64 %2, %6, %7, %%p;\n\t"
"}\n"
: "=r"(filter_r_), "=r"(next_idx), "=l"(reset_bytes)
: "r"(filter_r_), "r"(problem_size_.R), "r"(start_r_),
"l"(reset_bytes_s_), "l"(reset_bytes_r_));
#endif
}
// offset pointers by offset_bytes

View File

@ -67,7 +67,8 @@ template <
typename Element_,
typename Layout_,
typename ThreadMap_,
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>,
conv::GroupMode GroupMode_ = conv::GroupMode::kNone
>
class Conv2dFpropActivationTileAccessIteratorAnalytic {
public:
@ -89,6 +90,7 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 2;
using ConvProblemSize = typename conv::Conv2dProblemSize;
static conv::GroupMode const kGroupMode = GroupMode_;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
@ -119,6 +121,11 @@ private:
int filter_c_;
int filter_r_;
int filter_s_;
int filter_c_init_;
int group_idx_offset_;
int channels_per_group_;
int crs_cnt_;
int crs_per_group_;
int offset_n_[ThreadMap::Iterations::kStrided];
int offset_p_[ThreadMap::Iterations::kStrided];
@ -137,6 +144,8 @@ public:
params_(params),
problem_size_(problem_size),
pointer_(reinterpret_cast<char const *>(ptr)),
crs_cnt_(0),
group_idx_offset_(0),
filter_c_(0),
filter_r_(0),
filter_s_(0) {
@ -145,6 +154,12 @@ public:
filter_c_ = threadblock_offset.column() + thread_coord.contiguous();
if (kGroupMode != conv::GroupMode::kNone) {
filter_c_init_ = filter_c_;
channels_per_group_ = problem_size_.C / problem_size_.groups;
crs_per_group_ = problem_size_.S * problem_size_.R * ((channels_per_group_ + Shape::kColumn - 1) / Shape::kColumn);
}
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;
@ -182,6 +197,10 @@ public:
CUTLASS_HOST_DEVICE
void advance() {
// moves to the next tile
if (kGroupMode != conv::GroupMode::kNone) {
++crs_cnt_;
}
++filter_s_;
if (filter_s_ < problem_size_.S) {
return;
@ -192,8 +211,19 @@ public:
return;
}
filter_r_ = 0;
filter_c_ += Shape::kColumn * problem_size_.split_k_slices;
if (kGroupMode == conv::GroupMode::kNone) {
filter_c_ += Shape::kColumn * problem_size_.split_k_slices;
} else {
if (crs_cnt_ == crs_per_group_) {
// moves to next group
crs_cnt_ = 0;
++group_idx_offset_;
filter_c_ = group_idx_offset_ * channels_per_group_ + filter_c_init_;
} else {
filter_c_ += Shape::kColumn * problem_size_.split_k_slices;
}
}
}
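// Editorial illustration (hypothetical sizes): with C = 256, groups = 4, Shape::kColumn = 32
// and a 3x3 filter, channels_per_group_ = 64 and crs_per_group_ = 3 * 3 * ceil(64 / 32) = 18.
// After 18 calls to advance() the iterator has covered one group's R*S*C footprint and hops to
// the next group's first channel: ++group_idx_offset_; filter_c_ = group_idx_offset_ * 64 + filter_c_init_.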
/// Returns the coordinate in the activations tensor X that is currently pointed to

View File

@ -66,7 +66,8 @@ template <
typename Element_,
typename Layout_,
typename ThreadMap_,
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess>,
conv::GroupMode GroupMode_ = conv::GroupMode::kNone
>
class Conv2dFpropFilterTileAccessIteratorAnalytic {
public:
@ -88,6 +89,7 @@ public:
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 2;
using ConvProblemSize = typename conv::Conv2dProblemSize;
static conv::GroupMode const kGroupMode = GroupMode_;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
@ -118,8 +120,14 @@ private:
int filter_r_;
int filter_s_;
int filter_c_;
int filter_c_init_;
int crs_cnt_;
int crs_per_group_;
int group_idx_offset_c_;
int channels_per_group_;
int offset_k_[ThreadMap::Iterations::kStrided];
int group_idx_offset_k_[ThreadMap::Iterations::kStrided];
public:
@ -134,6 +142,8 @@ public:
params_(params),
problem_size_(problem_size),
pointer_(reinterpret_cast<char const *>(ptr)),
crs_cnt_(0),
group_idx_offset_c_(0),
filter_r_(0),
filter_s_(0),
filter_c_(0) {
@ -142,9 +152,23 @@ public:
filter_c_ = threadblock_offset.row() + thread_coord.contiguous();
if (kGroupMode != conv::GroupMode::kNone) {
filter_c_init_ = filter_c_;
if (kGroupMode == conv::GroupMode::kDepthwise) {
channels_per_group_ = 1;
crs_per_group_ = problem_size_.S * problem_size_.R;
} else {
channels_per_group_ = problem_size_.C / problem_size_.groups;
crs_per_group_ = problem_size_.S * problem_size_.R * ((channels_per_group_ + Shape::kRow - 1) / Shape::kRow);
}
}
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;
if (kGroupMode != conv::GroupMode::kNone && kGroupMode != conv::GroupMode::kDepthwise) {
group_idx_offset_k_[s] = (thread_coord.strided() + s * ThreadMap::Delta::kStrided) / (problem_size_.K / problem_size_.groups);
}
}
set_iteration_index(0);
@ -168,6 +192,10 @@ public:
CUTLASS_HOST_DEVICE
void advance() {
// moves to the next tile
if (kGroupMode != conv::GroupMode::kNone) {
++crs_cnt_;
}
++filter_s_;
if (filter_s_ < problem_size_.S) {
return;
@ -179,8 +207,21 @@ public:
return;
}
filter_r_ = 0;
filter_c_ += Shape::kRow * problem_size_.split_k_slices;
if (kGroupMode == conv::GroupMode::kNone) {
filter_c_ += Shape::kRow * problem_size_.split_k_slices;
} else {
if (crs_cnt_ == crs_per_group_) {
crs_cnt_ = 0;
filter_c_ = filter_c_init_;
if (kGroupMode != conv::GroupMode::kDepthwise) {
// moves to next group
++group_idx_offset_c_;
}
} else {
filter_c_ += Shape::kRow * problem_size_.split_k_slices;
}
}
}
/// Returns the coordinate in the filter tensor W that is currently pointed to
@ -200,8 +241,14 @@ public:
TensorCoord coord = at();
return coord.n() < problem_size_.K &&
coord.c() < problem_size_.C;
if (kGroupMode == conv::GroupMode::kNone) {
return coord.n() < problem_size_.K && coord.c() < problem_size_.C;
} else if (kGroupMode == conv::GroupMode::kDepthwise) {
return coord.n() < problem_size_.K && coord.c() < 1; // channels_per_group_ is always equal to ONE.
} else {
return coord.n() < problem_size_.K && coord.c() < channels_per_group_ &&
group_idx_offset_c_ == group_idx_offset_k_[iteration_strided_];
}
}
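// Editorial note (illustration, hypothetical sizes): for group conv with K / groups = 32 and a
// CTA tile spanning 128 output channels (4 groups per CTA), a thread whose strided offset within
// the tile is 70 has group_idx_offset_k_ = 70 / 32 = 2, so its filter loads are valid only while
// the channel walk is inside the tile's third group (group_idx_offset_c_ == 2).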
/// Returns a pointer to the vector starting at the current coordinate

View File

@ -554,20 +554,20 @@ struct Conv2dDgradOutputGradientIteratorOptimizedParams {
// next S
inc_next[0] = conv_sign * (
layout.stride()[0] * problem_size.dilation_w
(int64_t)layout.stride()[0] * problem_size.dilation_w
) * element_size_bits / 8;
// next R
inc_next[1] = conv_sign * (
layout.stride()[1] * problem_size.dilation_h
- (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w
(int64_t)layout.stride()[1] * problem_size.dilation_h
- (problem_size.S - 1) * (int64_t)layout.stride()[0] * problem_size.dilation_w
) * element_size_bits / 8;
// next K
inc_next[2] = (
threadblock_shape.column() * problem_size.split_k_slices
- conv_sign * (problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h
- conv_sign * (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w
- conv_sign * (problem_size.R - 1) * (int64_t)layout.stride()[1] * problem_size.dilation_h
- conv_sign * (problem_size.S - 1) * (int64_t)layout.stride()[0] * problem_size.dilation_w
) * element_size_bits / 8;
// logical offset added to internal channel counter - units are elements, not bytes
@ -614,12 +614,12 @@ struct Conv2dStridedDgradOutputGradientIteratorOptimizedParams {
// next S
inc_next[0] = conv_sign * (
layout.stride()[0] * problem_size.dilation_w
(int64_t)layout.stride()[0] * problem_size.dilation_w
) * element_size_bits / 8;
// next R
inc_next[1] = conv_sign * (
layout.stride()[1] * problem_size.dilation_h
(int64_t)layout.stride()[1] * problem_size.dilation_h
) * element_size_bits / 8;
// next K
@ -670,18 +670,18 @@ struct Conv2dDgradFilterIteratorOptimizedParams {
TRACE_CONV_INITIALIZERS("conv2d_dgrad", "filter",
element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta);
inc_next_strided = (layout.stride()[2] * threadmap_delta.strided() * element_size_bits) / 8;
inc_next_strided = ((int64_t)layout.stride()[2] * threadmap_delta.strided() * element_size_bits) / 8;
inc_next_rs =
( layout.stride()[0]
- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
( (int64_t)layout.stride()[0]
- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2]
) * element_size_bits / 8;
inc_next_k =
(
threadblock_shape.row() * problem_size.split_k_slices * layout.stride()[2]
- (problem_size.R * problem_size.S - 1) * layout.stride()[0]
- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[2]
- (problem_size.R * problem_size.S - 1) * (int64_t)layout.stride()[0]
- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2]
) * element_size_bits / 8;
filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices;
@ -730,26 +730,26 @@ struct Conv2dStridedDgradFilterIteratorOptimizedParams {
// next S
inc_next[0] =
( layout.stride()[0] * problem_size.stride_w
( (int64_t)layout.stride()[0] * problem_size.stride_w
//- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
) * element_size_bits / 8;
// next R
inc_next[1] =
( layout.stride()[1] * problem_size.stride_h
( (int64_t)layout.stride()[1] * problem_size.stride_h
//- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
) * element_size_bits / 8;
// next K
inc_next[2] =
(
threadblock_shape.row() * problem_size.split_k_slices * layout.stride()[2]
threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[2]
//- (problem_size.R * problem_size.S - 1) * layout.stride()[0]
//- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
) * element_size_bits / 8;
// offset in units of bytes to move the pointer in backward direction
reset_bytes = (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2]
reset_bytes = (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2]
* element_size_bits / 8;
filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices;
@ -800,13 +800,13 @@ struct Conv2dWgradOutputGradientIteratorOptimizedParams {
element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta);
// Incremental offsets in units of bytes (number of elements) * sizeof_bits<Element>::value / 8
offset_next_strided = (threadmap_delta.strided() * layout.stride()[0])
offset_next_strided = (threadmap_delta.strided() * (int64_t)layout.stride()[0])
* element_size_bits / 8;
offset_next_contiguous = (threadmap_delta.contiguous())
* element_size_bits / 8;
inc_next_npq = (threadblock_shape.column() * problem_size.split_k_slices * layout.stride()[0])
inc_next_npq = (threadblock_shape.column() * problem_size.split_k_slices * (int64_t)layout.stride()[0])
* element_size_bits / 8;
}
};
@ -891,4 +891,3 @@ struct PredicatedScaleBiasVectorAccessIteratorParams {
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
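The (int64_t) casts added throughout this file widen the stride arithmetic before it is scaled into a byte offset. A minimal standalone sketch of the failure mode they guard against, using hypothetical sizes rather than anything from this commit:
#include <cstdint>
#include <cstdio>
int main() {
  // Hypothetical NHWC tensor with H = W = 512, C = 256: the stride along N is H*W*C elements.
  int32_t stride_n = 512 * 512 * 256;      // 67,108,864
  int32_t delta_strided = 8;               // threadmap stride between accesses
  int32_t element_size_bits = 16;          // fp16
  // The 32-bit product overflows (signed overflow; in practice it wraps), the 64-bit one does not.
  int32_t narrow = stride_n * delta_strided * element_size_bits / 8;
  int64_t wide   = (int64_t)stride_n * delta_strided * element_size_bits / 8;   // 1,073,741,824 bytes
  printf("32-bit result: %d, 64-bit result: %lld\n", narrow, (long long)wide);
  return 0;
}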

View File

@ -104,6 +104,11 @@ public:
return TileAccessIterator::getParams(problem_size, layout);
}
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
tile_access_iterator_.set_iteration_index(index);
}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE

View File

@ -304,8 +304,8 @@ struct Conv3dDgradOutputGradientIteratorOptimizedParams {
// logical offset added to internal channel counter - units are elements, not bytes
filter_k_delta = threadblock_shape.column() * problem_size.split_k_slices;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Parameters object for Conv2d DGRAD Filter (w) iterator
@ -343,18 +343,18 @@ struct Conv3dDgradFilterIteratorOptimizedParams {
TRACE_CONV_INITIALIZERS("conv3d_dgrad", "filter",
element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta);
inc_next_strided = (layout.stride()[3] * threadmap_delta.strided() * element_size_bits) / 8;
inc_next_strided = ((int64_t)layout.stride()[3] * threadmap_delta.strided() * element_size_bits) / 8;
inc_next_trs =
( layout.stride()[0]
- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[3]
( (int64_t)layout.stride()[0]
- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[3]
) * element_size_bits / 8;
inc_next_k =
(
threadblock_shape.row() * problem_size.split_k_slices * layout.stride()[3]
- (problem_size.T * problem_size.R * problem_size.S - 1) * layout.stride()[0]
- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[3]
threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[3]
- (problem_size.T * problem_size.R * problem_size.S - 1) * (int64_t)layout.stride()[0]
- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[3]
) * element_size_bits / 8;
filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices;
@ -408,13 +408,13 @@ struct Conv3dWgradOutputGradientIteratorOptimizedParams {
element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta);
// Incremental offsets in units of bytes (number of elements) * element_size_bits / 8
offset_next_strided = (threadmap_delta.strided() * layout.stride()[0])
offset_next_strided = (threadmap_delta.strided() * (int64_t)layout.stride()[0])
* element_size_bits / 8;
offset_next_contiguous = (threadmap_delta.contiguous())
* element_size_bits / 8;
inc_next_nzpq = (threadblock_shape.column() * problem_size.split_k_slices * layout.stride()[0])
inc_next_nzpq = (threadblock_shape.column() * problem_size.split_k_slices * (int64_t)layout.stride()[0])
* element_size_bits / 8;
// Precompute several quantities for fast modulo arithmetic.
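The (int64_t) casts added throughout the hunks above keep the byte-offset arithmetic in 64-bit; a minimal standalone sketch (hypothetical values, not CUTLASS code) of the overflow they avoid:
#include <climits>
#include <cstdint>
#include <cstdio>
int main() {
  int32_t stride = 4'000'000;  // hypothetical layout.stride()[0] in elements
  int32_t count  = 2'048;      // hypothetical threadblock rows * split_k_slices
  int     bits   = 16;         // element width in bits
  // Without the cast, stride * count * bits is evaluated in 32-bit arithmetic and
  // overflows long before the assignment to a 64-bit field widens the result.
  int64_t widened = (int64_t)stride * count * bits / 8;
  std::printf("increment = %lld bytes (INT_MAX = %d)\n", (long long)widened, INT_MAX);
  return 0;
}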

View File

@ -0,0 +1,336 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/mma_base.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace conv {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Data type of accumulator matrix
typename ElementC_,
/// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Transformation applied to A operand
typename TransformA_ = NumericArrayConverter<
typename SmemIteratorA_::Element,
typename IteratorA_::Element,
IteratorA_::Fragment::kElements>,
///
/// Transformation applied to B operand
typename TransformB_ = NumericArrayConverter<
typename SmemIteratorB_::Element,
typename IteratorB_::Element,
IteratorB_::Fragment::kElements>,
/// Used for partial specialization
typename Enable = bool
>
class DepthwiseFpropPipelined : public gemm::threadblock::MmaBase<Shape_, Policy_, 2> {
public:
///< Base class
using Base = gemm::threadblock::MmaBase<Shape_, Policy_, 2>;
using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
using ElementC = ElementC_; ///< Data type of accumulator matrix
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
using Policy = Policy_; ///< Policy describing tuning details
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
//
// Dependent types
//
/// Fragment of operand A loaded from global memory
using FragmentA = typename IteratorA::Fragment;
/// Fragment of operand B loaded from global memory
using FragmentB = typename IteratorB::Fragment;
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Obtain the arch tag from the warp-level operator
using ArchTag = typename Policy::Operator::ArchTag;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
// statically assert kStages for MmaPipelined is two (double-buffered pipeline)
static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");
private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
protected:
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DepthwiseFpropPipelined(
typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx ///< ID of each thread within a warp
):
Base(shared_storage, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
int gemm_k_iterations, ///< number of iterations of the mainloop
FragmentC &accum, ///< destination accumulator tile
IteratorA iterator_A, ///< iterator over A operand in global memory
IteratorB iterator_B, ///< iterator over B operand in global memory
FragmentC const &src_accum, ///< source accumulator tile
int gemm_k_iterations_per_channel = 0, ///< number of iterations per channel
TransformA transform_A = TransformA(), ///< transformation applied to A fragment
TransformB transform_B = TransformB()) { ///< transformation applied to B fragment
//
// Prologue
//
// Perform accumulation in the 'd' output operand
accum = src_accum;
FragmentA tb_frag_A;
FragmentB tb_frag_B;
tb_frag_A.clear();
tb_frag_B.clear();
// The last kblock is loaded in the prolog
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
this->smem_iterator_A_.store(transform_A(tb_frag_A));
this->smem_iterator_B_.store(transform_B(tb_frag_B));
++this->smem_iterator_A_;
++this->smem_iterator_B_;
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
Operator warp_mma;
int smem_write_stage_idx = 1;
// Depthwise specific
int channel_start_index = 0;
int rs_plane_idx = 0;
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
// shared memory loads (which have the tightest latency requirement).
//
// Mainloop
//
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > 0; --gemm_k_iterations) {
//
// Loop over GEMM K dimension
//
if(rs_plane_idx == gemm_k_iterations_per_channel - 1){
// Reset iteration index.
iterator_B.set_iteration_index(0);
}
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
// Load warp-level tiles from shared memory, wrapping to the k offset if this is the last group.
if (warp_mma_k == Base::kWarpGemmIterations - 1) {
// Write fragments to shared memory
this->smem_iterator_A_.store(transform_A(tb_frag_A));
this->smem_iterator_B_.store(transform_B(tb_frag_B));
__syncthreads();
if(rs_plane_idx == gemm_k_iterations_per_channel - 1){
// Move to next set of filter groups.
channel_start_index += Base::kWarpGemmIterations;
}
++this->smem_iterator_A_;
++this->smem_iterator_B_;
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
if (smem_write_stage_idx == 1) {
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
}
else {
this->warp_tile_iterator_A_.add_tile_offset(
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations,
0});
}
smem_write_stage_idx ^= 1;
}
this->warp_tile_iterator_A_.set_kgroup_index(channel_start_index + (warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.set_kgroup_index(channel_start_index + (warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
if (warp_mma_k == 0) {
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
}
warp_mma(accum, warp_frag_A[warp_mma_k % 2],
warp_frag_B[warp_mma_k % 2], accum);
}
rs_plane_idx = (rs_plane_idx == gemm_k_iterations_per_channel - 1) ? 0: (rs_plane_idx + 1);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace conv
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
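A standalone control-flow sketch (hypothetical trip counts, simplified to one decision per mainloop iteration) of the depthwise-specific bookkeeping in the pipelined mainloop above: the two-stage ping-pong write index plus the rs_plane_idx counter that resets the B iterator and advances the channel group once per set of filter planes:
#include <cstdio>
int main() {
  const int kWarpGemmIterations            = 4;  // hypothetical
  const int gemm_k_iterations_per_channel  = 3;  // hypothetical filter positions per channel
  int gemm_k_iterations = 6;                     // hypothetical mainloop trip count
  int smem_write_stage_idx = 1;  // double-buffer ping-pong index
  int channel_start_index  = 0;  // k-group base fed to the warp tile iterators
  int rs_plane_idx         = 0;  // position within the current set of filter positions
  for (; gemm_k_iterations > 0; --gemm_k_iterations) {
    bool last_plane = (rs_plane_idx == gemm_k_iterations_per_channel - 1);
    if (last_plane) {
      // In the kernel: iterator_B.set_iteration_index(0), and on the last warp
      // iteration the channel group advances by kWarpGemmIterations.
      channel_start_index += kWarpGemmIterations;
    }
    smem_write_stage_idx ^= 1;  // flip which shared-memory buffer is written next
    rs_plane_idx = last_plane ? 0 : rs_plane_idx + 1;
  }
  std::printf("channel_start_index=%d smem_write_stage_idx=%d\n",
              channel_start_index, smem_write_stage_idx);
  return 0;
}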

View File

@ -0,0 +1,337 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data
layout of the global memory fragments, data types, and internal tile sizes.
Partial specializations for threadblock::Mma operations targeting depthwise-related SIMT instructions.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/gemm/warp/mma.h"
#include "cutlass/gemm/threadblock/mma_pipelined.h"
#include "cutlass/gemm/threadblock/mma_singlestage.h"
#include "cutlass/gemm/threadblock/mma_base.h"
#include "cutlass/conv/warp/mma_depthwise_simt.h"
#include "cutlass/arch/cache_operation.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace conv {
namespace threadblock {
template <
/// Shape of threadblock-scoped matrix multiply operator
typename Shape,
/// Shape of warp-level matrix multiply operator
typename WarpShape,
/// Shape of one matrix production operation (concept: GemmShape)
typename InstructionShape,
/// Element data type of A operand
typename ElementA,
/// Layout of operand A
typename LayoutA,
/// Element data type of B operand
typename ElementB,
/// Layout of operand B
typename LayoutB,
/// Data type of accumulator
typename ElementC,
/// Layout of accumulator
typename LayoutC,
/// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp)
typename OperatorClass,
/// Size of a warp-scoped per thread access
int kLaneAccessSizeA_ = 0,
/// Size of a warp-scoped per thread access
int kLaneAccessSizeB_ = 0,
/// Number of stages
int Stages = 2,
/// Operation performed by MMA
typename Operator = typename platform::conditional<
(platform::is_same<OperatorClass,
cutlass::arch::OpClassTensorOp>::value) &&
(platform::is_same<ElementA, int8_t>::value ||
platform::is_same<ElementA, int4b_t>::value ||
platform::is_same<ElementA, uint8_t>::value ||
platform::is_same<ElementA, uint4b_t>::value),
cutlass::arch::OpMultiplyAddSaturate,
cutlass::arch::OpMultiplyAdd>::type,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor = false,
/// Cache operation of operand A
cutlass::arch::CacheOperation::Kind CacheOpA =
cutlass::arch::CacheOperation::Global,
/// Cache operation of operand B
cutlass::arch::CacheOperation::Kind CacheOpB =
cutlass::arch::CacheOperation::Global,
/// per-element transformation for elements of A
ComplexTransform TransformA = ComplexTransform::kNone,
/// per-element transformation for elements of B
ComplexTransform TransformB = ComplexTransform::kNone,
bool IsComplex = false // (is_complex<ElementA>::value || is_complex<ElementB>::value)
>
struct DepthwiseMmaCoreWithLaneAccessSize;
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Shape of threadblock-scoped matrix multiply operator
typename Shape,
/// Shape of warp-level matrix multiply operator
typename WarpShape,
/// Shape of one matrix production operation (concept: GemmShape)
typename InstructionShape,
/// Element data type of A operand
typename ElementA,
/// Layout of operand A
typename LayoutA,
/// Element data type of B operand
typename ElementB,
/// Layout of operand B
typename LayoutB,
/// Data type of accumulator
typename ElementC,
/// Layout of accumulator
typename LayoutC,
/// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp)
typename OperatorClass,
/// Number of stages
int Stages,
/// Operation performed by MMA
typename Operator,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor,
/// Cache operation of operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Cache operation of operand B
cutlass::arch::CacheOperation::Kind CacheOpB,
/// per-element transformation for elements of A
ComplexTransform TransformA,
/// per-element transformation for elements of B
ComplexTransform TransformB,
bool IsComplex
>
struct DepthwiseMmaCoreWithLaneAccessSize<
Shape, WarpShape, InstructionShape,
ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
OperatorClass, -1, -1, Stages, Operator, AccumulatorsInRowMajor,
CacheOpA, CacheOpB, TransformA, TransformB, IsComplex
> : cutlass::gemm::threadblock::DefaultMmaCore<
Shape, WarpShape, InstructionShape,
ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
OperatorClass, Stages, Operator, AccumulatorsInRowMajor,
CacheOpA, CacheOpB, TransformA, TransformB, IsComplex
> {};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization:
///
/// A: row-major
/// B: column-major
/// Operator: simt class
///
/// This uses the default warp-level operator given tile sizes
template <
/// Shape of threadblock-scoped matrix multiply operator (concept:
/// GemmShape)
typename Shape_,
/// Shape of warp-level matrix multiply operator (concept: GemmShape)
typename WarpShape_,
/// Data type of A operand
typename ElementA_,
/// Data type of B operand
typename ElementB_,
/// Data type of accumulator
typename ElementC_,
/// Layout of accumulator
typename LayoutC_,
/// Size of a warp-scoped per thread access (a value of -1 indicates the default)
int kLaneAccessSizeA_,
/// Size of a warp-scoped per thread access (a value of -1 indicates the default)
int kLaneAccessSizeB_,
/// Operation performed by GEMM
typename Operator_>
struct DepthwiseMmaCoreWithLaneAccessSize<Shape_,
WarpShape_,
cutlass::gemm::GemmShape<1, 1, 1>,
ElementA_,
layout::RowMajor,
ElementB_,
layout::ColumnMajor,
ElementC_,
LayoutC_,
arch::OpClassSimt,
kLaneAccessSizeA_,
kLaneAccessSizeB_,
2,
Operator_> : public cutlass::gemm::threadblock::DefaultMmaCore<Shape_,
WarpShape_,
cutlass::gemm::GemmShape<1, 1, 1>,
ElementA_,
layout::RowMajor,
ElementB_,
layout::ColumnMajor,
ElementC_,
LayoutC_,
arch::OpClassSimt,
2,
Operator_> {
using Base = cutlass::gemm::threadblock::DefaultMmaCore<Shape_,
WarpShape_,
cutlass::gemm::GemmShape<1, 1, 1>,
ElementA_,
layout::RowMajor,
ElementB_,
layout::ColumnMajor,
ElementC_,
LayoutC_,
arch::OpClassSimt,
2,
Operator_>;
using Shape = Shape_;
using WarpShape = WarpShape_;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
using ElementA = ElementA_;
using LayoutA = layout::RowMajor;
using ElementB = ElementB_;
using LayoutB = layout::ColumnMajor;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
using OperatorClass = arch::OpClassSimt;
static int const kLaneAccessSizeA = kLaneAccessSizeA_;
static int const kLaneAccessSizeB = kLaneAccessSizeB_;
// Divisibility requirements
static_assert( kLaneAccessSizeA > 0 && kLaneAccessSizeB > 0,
"Size of a warp-scoped per thread access should be larger then ZERO" );
/// Default Operator
using Operator = Operator_;
/// Number of warps present
using WarpCount = typename Base::WarpCount;
// Divisibility requirements
static_assert(
!(Shape::kM % WarpShape::kM) &&
!(Shape::kN % WarpShape::kN),
"Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
);
/// Number of threads per warp
static int const kWarpSize = cutlass::gemm::warp::WarpSize<arch::OpClassSimt>::value;
static int const kElementsPerAccess = 1;
//
// Shared memory layouts
//
using SmemLayoutA = layout::ColumnMajor;
using SmemLayoutB = layout::RowMajor;
//
// Iterators to write to shared memory are same as base class
//
//
// Warp-level matrix multiply operator
//
// Define the warp-level op
static const int WarpNumThreadsM = cutlass::gemm::threadblock::detail::simt_get_warp_threads_m<WarpShape>();
static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM;
static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM;
static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN;
static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN),
"WarpShape must be divisible by ThreadTile shape.");
static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1;
static const int numElementsA = kLaneAccessSizeA / sizeof_bits<ElementA>::value;
static const int numElementsB = kLaneAccessSizeB / sizeof_bits<ElementB>::value;
static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM);
static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN);
static int const kPaddingM = cutlass::gemm::threadblock::detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits<ElementA>::value);
static int const kPaddingN = cutlass::gemm::threadblock::detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits<ElementB>::value);
static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN),
"Padding must be divisible by Lane");
// LaneMmaShape is additionally capped by the thread tile dimensions
using LaneMmaShape = cutlass::gemm::GemmShape<
LaneM,
LaneN,
1>;
using Policy = cutlass::gemm::warp::MmaSimtPolicy<
cutlass::MatrixShape<WarpNumThreadsM, WarpNumThreadsN>, // WarpShape
cutlass::layout::RowMajorInterleaved<LaneLayout>, // LaneLayout
LaneMmaShape
>;
using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseSimt<
WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<>
ElementA, /// Data type of A elements
SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout)
ElementB, /// Data type of B elements
SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout)
ElementC, /// Element type of C matrix
LayoutC, /// Layout of C matrix (concept: MatrixLayout)
Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy)
>;
/// Policy used to define MmaPipelined
using MmaPolicy = cutlass::gemm::threadblock::MmaPolicy<
MmaWarpSimt,
MatrixShape<kPaddingM, 0>, // skew for A matrix to avoid SMEM bank conflicts
MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts
WarpCount::kK
>;
};
} // namespace threadblock
} // namespace conv
} // namespace cutlass
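A minimal sketch (hypothetical sizes) of how the specialization above turns kLaneAccessSizeA/B into LaneMmaShape: divide the access size by the element width, then clamp by the thread tile:
#include <algorithm>
#include <cstdio>
int main() {
  const int kLaneAccessSizeA = 128;  // bits per warp-scoped per-thread access (hypothetical)
  const int kLaneAccessSizeB = 32;   // hypothetical
  const int element_bits     = 16;   // e.g. half_t
  const int ThreadTileM      = 8;    // WarpShape::kM / WarpNumThreadsM (hypothetical)
  const int ThreadTileN      = 4;    // WarpShape::kN / WarpNumThreadsN (hypothetical)
  int numElementsA = kLaneAccessSizeA / element_bits;  // 8 elements per access
  int numElementsB = kLaneAccessSizeB / element_bits;  // 2 elements per access
  int LaneM = std::min(numElementsA, ThreadTileM);     // 8
  int LaneN = std::min(numElementsB, ThreadTileN);     // 2
  std::printf("LaneMmaShape = <%d, %d, 1>\n", LaneM, LaneN);
  return 0;
}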

View File

@ -64,7 +64,7 @@
#include "cutlass/arch/cache_operation.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/conv/warp/conv2d_fprop_scale_bias_iterator.h"
#include "cutlass/gemm/warp/scale_bias_tile_iterator.h"
#include "cutlass/conv/warp/scale_bias_relu_transform.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -139,6 +139,13 @@ class MmaFpropFusionBase {
/// Tensor reference to the B operand
using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
static_assert(kWarpGemmIterations > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
static_assert((kWarpGemmIterations % 2) == 0,
"Inner loop iteration must be an even number.");
//
// Nested structs
//
@ -319,7 +326,7 @@ class ImplicitGemmFpropFusionMultistage
using Policy = Policy_;
///< Base class
using Base = MmaFpropFusionBase<Shape_, typename IteratorScaleBias::Element,
typename IteratorScaleBias::Layout, Policy_,
typename IteratorScaleBias::Layout, Policy,
WarpIteratorScaleBias, Stages>;
using SmemIteratorA = SmemIteratorA_;
@ -518,6 +525,8 @@ public:
IteratorScaleBias iterator_A_scale_bias,
///< initial value of accumulator
FragmentC const &src_accum,
///< number of iterations per channel
int gemm_k_iterations_per_channel = 0,
///< Imaginary strides used for planar-complex only - ignored here
int64_t imag_stride_A = 0,
int64_t imag_stride_B = 0) {

View File

@ -116,10 +116,6 @@ public:
/// Internal structure exposed for introspection.
struct Detail {
static_assert(Base::kWarpGemmIterations > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
/// Number of cp.async instructions to load one stage of operand A
static int const AsyncCopyIterationsPerStageA =
IteratorA::ThreadMap::Iterations::kCount;
@ -272,6 +268,8 @@ public:
IteratorB iterator_B,
///< initial value of accumulator
FragmentC const &src_accum,
///< number of iterations per channel
int gemm_k_iterations_per_channel = 0,
///< Imaginary strides used for planar-complex only - ignored here
int64_t imag_stride_A = 0,
int64_t imag_stride_B = 0) {
@ -297,7 +295,7 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
int const kSrcBytes =
int const kSrcBytes =
sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
@ -322,7 +320,7 @@ public:
this->smem_iterator_B_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
int const kSrcBytes =
sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /

View File

@ -188,6 +188,7 @@ public:
IteratorA iterator_A, ///< iterator over A operand in global memory
IteratorB iterator_B, ///< iterator over B operand in global memory
FragmentC const &src_accum, ///< source accumulator tile
int gemm_k_iterations_per_channel = 0, ///< number of iterations per channel
TransformA transform_A = TransformA(), ///< transformation applied to A fragment
TransformB transform_B = TransformB()) { ///< transformation applied to B fragment

View File

@ -70,7 +70,7 @@
#include "cutlass/arch/cache_operation.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/conv/warp/conv2d_fprop_scale_bias_iterator.h"
#include "cutlass/gemm/warp/scale_bias_tile_iterator.h"
#include "cutlass/conv/warp/scale_bias_relu_transform.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -138,6 +138,13 @@ class MmaWgradFusionBase {
/// Tensor reference to the B operand
using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
static_assert(kWarpGemmIterations > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
static_assert((kWarpGemmIterations % 2) == 0,
"Inner loop iteration must be an even number.");
//
// Nested structs
//
@ -306,10 +313,6 @@ class ImplicitGemmWgradFusionMultistage
/// Internal structure exposed for introspection.
struct Detail {
static_assert(Base::kWarpGemmIterations > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
/// Number of cp.async instructions to load one stage of operand A
static int const AsyncCopyIterationsPerStageA =
IteratorA::ThreadMap::Iterations::kCount;
@ -470,6 +473,8 @@ public:
IteratorScaleBias iterator_B_scale_bias,
///< initial value of accumulator
FragmentC const &src_accum,
///< number of iterations per channel
int gemm_k_iterations_per_channel = 0,
///< Imaginary strides used for planar-complex only - ignored here
int64_t imag_stride_A = 0,
int64_t imag_stride_B = 0) {

View File

@ -113,12 +113,9 @@ class PredicatedScaleBiasVectorAccessIterator<ThreadblockShape_,
/// Internal pointer to first access of tile
BytePointer pointer_;
/// Size of tensor
Conv2dProblemSize problem_size_;
int filter_c_;
int filter_r_;
int filter_s_;
int problem_size_trs;
int problem_size_c;
int filter_trs_;
TensorCoord thread_offset_;
@ -140,10 +137,43 @@ class PredicatedScaleBiasVectorAccessIterator<ThreadblockShape_,
/// Initial offset of threadblock
TensorCoord const &threadblock_offset)
: params_(params),
problem_size_(problem_size),
filter_c_(0),
filter_r_(0),
filter_s_(0) {
problem_size_trs(problem_size.R * problem_size.S),
problem_size_c(problem_size.C),
filter_trs_(0) {
pointer_ = (thread_id < kThreads)
? reinterpret_cast<BytePointer>(
const_cast<NonConstPointer>(scale_pointer))
: reinterpret_cast<BytePointer>(
const_cast<NonConstPointer>(bias_pointer));
// Per-thread offset in logical coordinates of tensor
int thread_base = (thread_id < kThreads) ? 0 : kThreads;
thread_offset_ =
threadblock_offset +
TensorCoord((thread_id - thread_base) * kElementsPerAccess, 0);
set_iteration_index(0);
}
CUTLASS_HOST_DEVICE
PredicatedScaleBiasVectorAccessIterator(
/// Precomputed parameters object
Params const &params,
/// Extent of tensor
Conv3dProblemSize const &problem_size,
/// Pointer to the start of the scale vector
ConstPointer scale_pointer,
/// Pointer to the start of the bias vector
ConstPointer bias_pointer,
/// ID of each participating thread
int thread_id,
/// Initial offset of threadblock
TensorCoord const &threadblock_offset)
: params_(params),
problem_size_trs(problem_size.T * problem_size.R * problem_size.S),
problem_size_c(problem_size.C),
filter_trs_(0) {
pointer_ = (thread_id < kThreads)
? reinterpret_cast<BytePointer>(
const_cast<NonConstPointer>(scale_pointer))
@ -177,6 +207,22 @@ class PredicatedScaleBiasVectorAccessIterator<ThreadblockShape_,
scale_pointer, bias_pointer,
thread_id, make_Coord(0, 0)) {}
CUTLASS_HOST_DEVICE
PredicatedScaleBiasVectorAccessIterator(
/// Precomputed parameters object
Params const &params,
/// Extent of tensor
Conv3dProblemSize const &problem_size,
/// Pointer to start of scale vector
ConstPointer scale_pointer,
/// Pointer to start of bias vector
ConstPointer bias_pointer,
///< ID of each participating thread
int thread_id)
: PredicatedScaleBiasVectorAccessIterator(params, problem_size,
scale_pointer, bias_pointer,
thread_id, make_Coord(0, 0)) {}
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(int index) {}
@ -209,16 +255,10 @@ class PredicatedScaleBiasVectorAccessIterator<ThreadblockShape_,
CUTLASS_HOST_DEVICE
void advance() {
// moves to the next tile
++filter_s_;
if (filter_s_ == problem_size_.S) {
filter_s_ = 0;
++filter_r_;
if (filter_r_ < problem_size_.R) {
} else {
filter_r_ = 0;
add_tile_offset(TensorCoord(1, 0));
}
++filter_trs_;
if (filter_trs_ == problem_size_trs) {
filter_trs_ = 0;
add_tile_offset(TensorCoord(1, 0));
}
}
@ -248,7 +288,7 @@ class PredicatedScaleBiasVectorAccessIterator<ThreadblockShape_,
"}\n" : "+r"(enabled) :"n"(kThreads * 2));
#endif
return ((thread_offset_.contiguous() < problem_size_.C) && enabled);
return ((thread_offset_.contiguous() < problem_size_c) && enabled);
}
};
@ -322,6 +362,25 @@ class PredicatedScaleBiasVectorAccessIterator<ThreadblockShape_,
layout::PitchLinearCoord(threadblock_offset.column(),
threadblock_offset.row())) {}
CUTLASS_HOST_DEVICE
PredicatedScaleBiasVectorAccessIterator(
///< Precomputed parameters object
Params const &params,
///< Extent of tensor
Conv3dProblemSize const &problem_size,
///< Pointer to the start of the scale vector
ConstPointer scale_pointer,
///< Pointer to the start of the bias vector
ConstPointer bias_pointer,
///< ID of each participating thread
int thread_id,
///< Initial offset of threadblock
TensorCoord const &threadblock_offset)
: iterator_(params, problem_size, scale_pointer, bias_pointer,
thread_id,
layout::PitchLinearCoord(threadblock_offset.column(),
threadblock_offset.row())) {}
/// Construct a PredicatedTileAccessIterator with zero threadblock offset
CUTLASS_HOST_DEVICE
PredicatedScaleBiasVectorAccessIterator(
@ -335,6 +394,18 @@ class PredicatedScaleBiasVectorAccessIterator<ThreadblockShape_,
scale_pointer, bias_pointer,
thread_id, make_Coord(0, 0)) {}
CUTLASS_HOST_DEVICE
PredicatedScaleBiasVectorAccessIterator(
Params const &params, ///< Precomputed parameters object
Conv3dProblemSize const &problem_size, ///< Extent of tensor
ConstPointer scale_pointer, ///< Pointer to the start of the scale vector
ConstPointer bias_pointer, ///< Pointer to the start of the bias vector
int thread_id ///< ID of each participating thread
)
: PredicatedScaleBiasVectorAccessIterator(params, problem_size,
scale_pointer, bias_pointer,
thread_id, make_Coord(0, 0)) {}
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }

View File

@ -157,7 +157,6 @@ struct StridedDgradIdentityThreadblockSwizzle :
split_k_slices);
}
/// Returns the shape of the problem in units of logical tiles
/// For GEMM problem size (MxNxK) (Do not use base class get_tiled_shape())
private:

View File

@ -0,0 +1,163 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates implementing warp-level matrix multiply-accumulate operations.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma.h"
#include "cutlass/gemm/thread/mma.h"
#include "cutlass/gemm/warp/mma_simt_tile_iterator.h"
#include "cutlass/gemm/warp/mma_simt_policy.h"
#include "cutlass/gemm/warp/mma_simt.h"
#include "cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace conv {
namespace warp {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Data type of A elements
typename ElementA_,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA_,
/// Data type of B elements
typename ElementB_,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB_,
/// Element type of C matrix
typename ElementC_,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC_,
/// Shape of the warp in units of thread (concept: MmaSimtPolicy)
typename Policy_,
/// Number of partitions along K dimension
int PartitionsK = 1,
/// Complex transformation on operand A
ComplexTransform TransformA = ComplexTransform::kNone,
/// Complex transformation on operand B
ComplexTransform TransformB = ComplexTransform::kNone,
/// Used for partial specialization
typename Enable = bool>
class MmaDepthwiseSimt
: public cutlass::gemm::warp::
MmaSimt<Shape_, ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, LayoutC_, Policy_> {
using Base = cutlass::gemm::warp::
MmaSimt<Shape_, ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, LayoutC_, Policy_>;
public:
/// Shape of warp-level matrix operation (concept: GemmShape)
using Shape = Shape_; // e.g. <64, 16, 8>
/// Data type of multiplicand A
using ElementA = ElementA_;
/// Layout of multiplicand A
using LayoutA = LayoutA_;
/// Data type of multiplicand B
using ElementB = ElementB_;
/// Layout of multiplicand B
using LayoutB = LayoutB_;
/// Data type of accumulator matrix C
using ElementC = ElementC_;
/// Layout of accumulator matrix C
using LayoutC = LayoutC_;
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
using Policy = Policy_;
/// Indicates class of matrix operator
using OperatorClass = arch::OpClassSimt;
/// Hard-coded for now
using ArchTag = arch::Sm50;
/// Complex transform on A operand
static ComplexTransform const kTransformA = TransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = TransformB;
public:
/// Iterates over the B operand in memory
using IteratorB = cutlass::conv::warp::DepthwiseMmaSimtTileIterator<
MatrixShape<Policy::LaneMmaShape::kK, Shape::kN>,
cutlass::gemm::Operand::kB,
ElementB,
LayoutB,
Policy,
PartitionsK,
Shape::kK
>;
/// Storage for B tile
using FragmentB = typename IteratorB::Fragment;
/// Storage for transformed B tile
using TransformedFragmentB = FragmentB;
public:
//
// Methods
//
/// Ctor
CUTLASS_DEVICE
MmaDepthwiseSimt():Base() {}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace conv
} // namespace cutlass

View File

@ -0,0 +1,255 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tile iterators for operands to warp-level matrix multiply operations targeting SIMT
instructions, specialized for depthwise convolution
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma_simt_policy.h"
#include "cutlass/gemm/warp/mma_simt_tile_iterator.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace conv {
namespace warp {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Iterates over operands to warp-level matrix multiply operations targeting SIMT instructions
///
/// concept: MutableRandomAccessContiguousTileIteratorConcept
///
template <
/// Size of the matrix to load (concept: MatrixShape)
typename Shape_,
/// Operand identity
cutlass::gemm::Operand Operand,
/// Data type of elements
typename Element_,
/// Layout of operand
typename Layout_,
/// Shape of the warp in units of thread (concept: MmaSimtPolicy)
typename Policy_,
/// Number of partitions along K dimension - used in sliced-K
int PartitionsK = 1,
/// Group Size along kPartition - used in sliced-K
int PartitionGroupSize = 1
>
class DepthwiseMmaSimtTileIterator;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Specialization for B operands of row-major layouts
///
/// Concept: MutableRandomAccessContiguousTileIteratorConcept
///
template <
/// Size of the matrix to load (concept: MatrixShape)
typename Shape_,
/// Data type of B elements
typename Element_,
/// Shape of the warp in units of thread (concept: MmaSimtPolicy)
typename Policy_,
/// Number of partitions along K dimension
int PartitionsK,
/// Group Size along kPartition - used in sliced-K
int PartitionGroupSize>
class DepthwiseMmaSimtTileIterator<Shape_,
cutlass::gemm::Operand::kB,
Element_,
layout::RowMajor,
Policy_,
PartitionsK,
PartitionGroupSize>
: public cutlass::gemm::warp::MmaSimtTileIterator<Shape_,
cutlass::gemm::Operand::kB,
Element_,
layout::RowMajor,
Policy_,
PartitionsK,
PartitionGroupSize> {
using Base = cutlass::gemm::warp::MmaSimtTileIterator<Shape_,
cutlass::gemm::Operand::kB,
Element_,
layout::RowMajor,
Policy_,
PartitionsK,
PartitionGroupSize>;
public:
/// Shape of tile to load (concept: MatrixShape)
using Shape = Shape_;
/// Operand tag
static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kB;
/// Element type
using Element = Element_;
/// Layout of operand
using Layout = layout::RowMajor;
/// Decomposition of elements among threads
using Policy = Policy_;
/// TensorRef type for loading element from a tensor
using TensorRef = typename Base::TensorRef;
/// Index type
using Index = typename TensorRef::Index;
/// Long Index type
using LongIndex = typename TensorRef::LongIndex;
/// Coordinate for an element in the tensor
using TensorCoord = typename TensorRef::TensorCoord;
/// Thread-level shape of a fragment
using ThreadShape = typename Base::ThreadShape;
/// Number of individual loads
using Iterations = typename Base::Iterations;
/// Fragment object holding a thread's part of a tile
using Fragment = typename Base::Fragment;
static_assert(Policy::LaneMmaShape::kN == 1, "Each thread should load 1 element per LDS along the k-dim");
private:
MatrixCoord lane_offset_;
int channel_idx_;
int base_channel_idx_;
int warps_n_;
public:
/// Default ctor constructs null iterator
CUTLASS_HOST_DEVICE
DepthwiseMmaSimtTileIterator():Base() { }
/// Constructor from TensorRef
CUTLASS_HOST_DEVICE
DepthwiseMmaSimtTileIterator(
TensorRef ref,
int lane_id
) : Base(ref, lane_id) {
// compute offset based on thread ID and lane layout
typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
warps_n_ = -1;
channel_idx_ = 0;
base_channel_idx_ = 0;
lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN);
}
/// Advances an iterator along logical dimensions of matrix in units of whole tiles
CUTLASS_HOST_DEVICE
DepthwiseMmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) {
if(warps_n_ == -1){
warps_n_ = coord.column();
}
Base::add_tile_offset(coord);
return *this;
}
/// Loads a fragment from memory at the location pointed to by the iterator. (vector loads)
CUTLASS_HOST_DEVICE
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {
Array<Element, Policy::LaneMmaShape::kN> *dst_ptr =
reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < Iterations::kRow; ++k) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Iterations::kColumn; ++n) {
void const *ptr = this->ref_.data() +
this->ref_.offset({-(channel_idx_ - base_channel_idx_),
n * Policy::WarpShape::kColumn}) +
pointer_offset / Policy::LaneMmaShape::kN;
// Base k of the warp plus the base k of the current thread.
int thread_k_base_idx =
warps_n_ * Shape::kColumn / Policy::LaneMmaShape::kN + lane_offset_.column();
if (channel_idx_ + k == thread_k_base_idx + n * Policy::WarpShape::kColumn) {
// The depthwise kernel performs computation only when channel == k.
// Load an element only when the current computation channel matches the k assigned to this thread.
arch::shared_load(dst_ptr[n + k * Iterations::kColumn], ptr);
} else {
// Otherwise fill with zero to avoid an unnecessary shared-memory load.
dst_ptr[n + k * Iterations::kColumn].fill(Element(0));
}
}
}
}
/// Loads a fragment from memory at the location pointed to by the iterator.
CUTLASS_HOST_DEVICE
void load(Fragment &frag) const {
load_with_pointer_offset(frag, 0);
}
/// Notify the iterator which k-group it is currently pointing to.
///
/// This does not advance the iterator. Rather, it overrides its internal
/// tracking with constant-valued k-group index
CUTLASS_DEVICE
void set_kgroup_index(int k_group) {
if(k_group % PartitionGroupSize == 0 && k_group != 0){
base_channel_idx_ = k_group;
}
channel_idx_ = k_group;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace conv
} // namespace cutlass
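A host-side sketch (hypothetical indices) of the predicate used by load_with_pointer_offset() above: a lane loads a B element only when the channel being computed matches the k assigned to that fragment slot, and fills zero otherwise:
#include <cstdio>
int main() {
  const int Iterations_kRow = 2, Iterations_kColumn = 2;  // hypothetical
  const int WarpShape_kColumn = 4;                        // hypothetical
  int channel_idx = 5, thread_k_base_idx = 5;             // hypothetical iterator state
  for (int k = 0; k < Iterations_kRow; ++k) {
    for (int n = 0; n < Iterations_kColumn; ++n) {
      bool load = (channel_idx + k == thread_k_base_idx + n * WarpShape_kColumn);
      std::printf("k=%d n=%d -> %s\n", k, n, load ? "shared_load" : "fill(0)");
    }
  }
  return 0;
}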

View File

@ -101,7 +101,7 @@ struct FpropScaleBiasReluTransform {
"}\n"
: "=r"(ptr_activations[0])
: "r"(ptr_scale_bias[0]), "r"(ptr_activations[0]),
"r"(ptr_scale_bias[1]), "n"(0x7eff7eff));
"r"(ptr_scale_bias[1]), "n"(cutlass::arch::OOB_NAN_F16x2));
#else
// TODO: write emulation code
assert(0);
@ -151,8 +151,8 @@ struct WgradScaleBiasReluTransform {
#if 1
// CUDA + PTX version
bool h1_oob = (reinterpret_cast<uint16_t &>(ptr_activations[0].x) == 0x7eff);
bool h2_oob = (reinterpret_cast<uint16_t &>(ptr_activations[0].y) == 0x7eff);
bool h1_oob = (reinterpret_cast<uint16_t &>(ptr_activations[0].x) == cutlass::arch::OOB_NAN_F16);
bool h2_oob = (reinterpret_cast<uint16_t &>(ptr_activations[0].y) == cutlass::arch::OOB_NAN_F16);
// Apply per channel scale+bias+relu if the data is not a special NaN
// (0x7eff). If it is a special NaN (0x7eff), hard code the output to 0.
@ -161,7 +161,7 @@ struct WgradScaleBiasReluTransform {
// out-of-bound because C x R x S can be an odd number.
asm volatile(
"{\n\t"
" fma.rn.f16x2.relu %0 , %1, %2, %3;\n"
" fma.rn.f16x2.relu %0, %1, %2, %3;\n"
"}"
: "=r"(reinterpret_cast<uint32_t &>(ptr_activations[0]))
: "r"(ptr_scale_bias[0]), "r"(reinterpret_cast<uint32_t &>(ptr_activations[0])),
@ -195,7 +195,7 @@ struct WgradScaleBiasReluTransform {
"}\n"
: "=r"(reinterpret_cast<uint32_t &>(ptr_activations[0]))
: "r"(ptr_scale_bias[0]), "r"(reinterpret_cast<uint32_t &>(ptr_activations[0])),
"r"(ptr_scale_bias[1]), "n"(0x7eff), "n"(0xffff0000), "n"(0x0000ffff));
"r"(ptr_scale_bias[1]), "n"(cutlass::arch::OOB_NAN_F16), "n"(0xffff0000), "n"(0x0000ffff));
#endif
#else
// TODO: write emulation code
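A host-side sketch of the sentinel check these hunks perform: the named constants cutlass::arch::OOB_NAN_F16 / OOB_NAN_F16x2 replace the literal 0x7eff / 0x7eff7eff, and lanes carrying the marker are forced to zero instead of receiving scale*x + bias (the local constant below is a stand-in):
#include <cstdint>
#include <cstdio>
constexpr uint16_t kOobNanF16 = 0x7effu;  // stand-in for cutlass::arch::OOB_NAN_F16
// Returns true if a raw fp16 bit pattern is the out-of-bounds marker.
bool is_oob_lane(uint16_t raw_half) {
  return raw_half == kOobNanF16;
}
int main() {
  uint16_t lanes[2] = {0x3c00u /* 1.0 in fp16 */, kOobNanF16};
  for (uint16_t raw : lanes) {
    std::printf("0x%04x -> %s\n", raw, is_oob_lane(raw) ? "zeroed" : "scale+bias+relu");
  }
  return 0;
}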

View File

@ -43,7 +43,7 @@
////////////////////////////////////////////////////////////////////////////////////////////////////
#define CUTLASS_UNUSED(expr) do { (void)(expr); } while (0)
#define CUTLASS_UNUSED(expr) do { ; } while (&expr != &expr)
#if !defined(__CUDACC_RTC__)
@ -192,4 +192,3 @@ CUTLASS_HOST_DEVICE bool thread0() {
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////////////////////////
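A usage sketch of the revised macro above (re-stated locally so the example is self-contained): the always-false address comparison references the variable without reading its value, which silences unused-variable warnings without evaluating the expression:
#include <cstdio>
// Mirrors the new definition in the hunk above.
#define MY_UNUSED(expr) do { ; } while (&expr != &expr)
int main() {
  int debug_only = 42;     // set but otherwise unused
  MY_UNUSED(debug_only);   // no warning, no side effects on debug_only
  std::printf("done\n");
  return 0;
}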

View File

@ -98,6 +98,32 @@ struct ReLu<Array<T, N>> {
}
};
// Leaky Relu operator
template <typename T>
struct LeakyReLU {
CUTLASS_HOST_DEVICE
T operator()(T const &value, T const & alpha_recip) const {
T res = value > T(0) ? value : value * alpha_recip;
return res;
}
};
template <typename T, int N>
struct LeakyReLU<Array<T, N> > {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs, T const & alpha_recip) const {
Array<T, N> y;
LeakyReLU<T> leaky_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < int(rhs.size()); ++i) {
y[i] = leaky_op(rhs[i], alpha_recip);
}
return y;
}
};
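A usage sketch for the functors above (scalar form re-stated so the example compiles on its own; the leak coefficient value is hypothetical):
#include <cstdio>
template <typename T>
struct LeakyReLU {
  T operator()(T const &value, T const &alpha_recip) const {
    // Pass positive values through; scale negative values by the leak coefficient.
    return value > T(0) ? value : value * alpha_recip;
  }
};
int main() {
  LeakyReLU<float> op;
  float alpha = 0.1f;  // leak coefficient (hypothetical)
  float inputs[4] = {-2.0f, -0.5f, 0.0f, 3.0f};
  for (float x : inputs) {
    std::printf("LeakyReLU(%5.2f) = %5.2f\n", x, op(x, alpha));
  }
  return 0;
}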
// Tanh operator
template <typename T>
struct Tanh {
@ -135,32 +161,6 @@ struct Tanh<Array<half_t, N>> {
}
};
// Leaky Relu operator
template <typename T>
struct LeakyReLU {
CUTLASS_HOST_DEVICE
T operator()(T const &value, T const & alpha_recip) const {
T res = value > T(0) ? value : value * alpha_recip;
return res;
}
};
template <typename T, int N>
struct LeakyReLU<Array<T, N> > {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs, T const & alpha_recip) const {
Array<T, N> y;
LeakyReLU<T> leaky_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < int(rhs.size()); ++i) {
y[i] = leaky_op(rhs[i], alpha_recip);
}
return y;
}
};
// Sigmoid operator
template <typename T>
struct Sigmoid {

View File

@ -157,7 +157,7 @@ public:
if (k_partition) {
beta_ = ElementCompute(1);
}
if (k_partition != k_partition_count - 1) {
skip_elementwise_ = true;
}

View File

@ -65,6 +65,8 @@
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/layout/permute.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
@ -79,7 +81,8 @@ template <
typename WarpMmaSimt_,
typename OutputOp_,
int ElementsPerAccess,
bool ScatterD = false
bool ScatterD = false,
typename PermuteDLayout = layout::NoPermute
>
struct DefaultEpilogueSimt {
@ -109,7 +112,8 @@ struct DefaultEpilogueSimt {
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
OutputTileThreadMap,
ElementOutput,
ScatterD
ScatterD,
PermuteDLayout
>;
using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt<
@ -310,7 +314,6 @@ struct DefaultEpilogueSimtAffineRankN {
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

View File

@ -74,6 +74,8 @@
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/interleaved_epilogue.h"
#include "cutlass/layout/permute.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
@ -166,7 +168,7 @@ template <
typename ThreadMap
>
struct DefaultIteratorsTensorOp<float, int32_t, 4, ThreadblockShape, WarpShape, InstructionShape, ThreadMap> {
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
WarpShape,
InstructionShape,
@ -265,7 +267,7 @@ struct DefaultIteratorsTensorOp<
layout::RowMajor
>;
using WarpTileIterator = typename cutlass::platform::conditional<
using WarpTileIterator = typename platform::conditional<
(ThreadblockShape::kN == 256),
WarpTileIteratorNotMixed,
WarpTileIteratorMixed>::type;
@ -284,7 +286,7 @@ struct DefaultIteratorsTensorOp<
int32_t
>;
using SharedLoadIterator = typename cutlass::platform::conditional<
using SharedLoadIterator = typename platform::conditional<
(ThreadblockShape::kN == 256),
SharedLoadIteratorNotMixed,
SharedLoadIteratorMixed>::type;
@ -302,7 +304,8 @@ template <
int PartitionsK,
typename OutputOp_,
int ElementsPerAccess,
bool ScatterD = false
bool ScatterD = false,
typename PermuteDLayout = layout::NoPermute
>
struct DefaultEpilogueTensorOp {
@ -334,6 +337,7 @@ struct DefaultEpilogueTensorOp {
OutputTileThreadMap,
ElementOutput,
ScatterD,
PermuteDLayout,
UseCUDAStore
>;
@ -570,7 +574,6 @@ struct DefaultEpilogueTensorOpAffineRankN {
};
////////////////////////////////////////////////////////////////////////////////
/// Defines sensible defaults for epilogues for TensorOps which use
/// interleaved output layout. For this case, shared memory is not needed.
template <typename Shape_, typename WarpMmaTensorOp_, int PartitionsK,

View File

@ -66,6 +66,8 @@
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/layout/permute.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
@ -81,7 +83,8 @@ template <
int PartitionsK,
typename OutputOp_,
int ElementsPerAccess,
bool ScatterD = false
bool ScatterD = false,
typename PermuteDLayout = layout::NoPermute
>
struct DefaultEpilogueVoltaTensorOp {
@ -111,7 +114,8 @@ struct DefaultEpilogueVoltaTensorOp {
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
OutputTileThreadMap,
ElementOutput,
ScatterD
ScatterD,
PermuteDLayout
>;
using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp<
@ -326,7 +330,6 @@ struct DefaultEpilogueVoltaTensorOpAffineRankN {
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

View File

@ -49,6 +49,8 @@
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h"
#include "cutlass/layout/permute.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
@ -67,7 +69,8 @@ template <
typename ElementVector,
typename OutputOp,
int ElementsPerAccess,
bool ScatterD = false
bool ScatterD = false,
typename PermuteDLayout = layout::NoPermute
>
struct DefaultEpilogueWithBroadcastTensorOp {
@ -86,7 +89,8 @@ struct DefaultEpilogueWithBroadcastTensorOp {
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
typename Base::OutputTileThreadMap,
ElementOutput,
ScatterD
ScatterD,
PermuteDLayout
>;
//

View File

@ -50,6 +50,8 @@
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/epilogue_with_reduction.h"
#include "cutlass/layout/permute.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
@ -67,7 +69,8 @@ template <
typename OutputOp,
typename ReductionOp,
int ElementsPerAccess,
bool ScatterD = false
bool ScatterD = false,
typename PermuteDLayout = layout::NoPermute
>
struct DefaultEpilogueWithReductionTensorOp {
@ -89,7 +92,8 @@ struct DefaultEpilogueWithReductionTensorOp {
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
typename Base::OutputTileThreadMap,
ElementOutput,
ScatterD
ScatterD,
PermuteDLayout
>;
/// Define the epilogue
@ -120,7 +124,8 @@ template <
typename OutputOp,
typename ReductionOp,
int ElementsPerAccess,
bool ScatterD = false
bool ScatterD = false,
typename PermuteDLayout = layout::NoPermute
>
struct DefaultEpilogueWithReductionVoltaTensorOp {
@ -142,7 +147,8 @@ struct DefaultEpilogueWithReductionVoltaTensorOp {
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
typename Base::OutputTileThreadMap,
ElementOutput,
ScatterD
ScatterD,
PermuteDLayout
>;
/// Define the epilogue

View File

@ -64,6 +64,8 @@
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/layout/permute.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
@ -79,7 +81,8 @@ template <
int PartitionsK,
typename OutputOp_,
int ElementsPerAccess,
bool ScatterD = false
bool ScatterD = false,
typename PermuteDLayout = layout::NoPermute
>
struct DefaultEpilogueWmmaTensorOp {
@ -109,7 +112,8 @@ struct DefaultEpilogueWmmaTensorOp {
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
OutputTileThreadMap,
ElementOutput,
ScatterD
ScatterD,
PermuteDLayout
>;
using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorWmmaTensorOp<

View File

@ -0,0 +1,513 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue visitor for threadblock-scoped GEMMs that perform softmax partial computations in the epilogue.
The epilogue finds the maximum value in each row of the row-major output matrix and stores it.
The maximum values are also used in a further round of threadblock-scoped reduction, whose partial
results are stored in a pre-allocated array and later combined in a full reduction.
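As a sketch of the math (added here for clarity): for each row x covered by a threadblock tile, the
visitor produces
  M = max_j x_j                  (written through ptr_Max)
  S = sum_j exp(x_j - M)         (written through ptr_Sum)
so that a later full reduction can combine the partials (M_i, S_i) across tiles as
  M_full = max_i M_i,   S_full = sum_i S_i * exp(M_i - M_full)
and the final softmax value of an element is exp(x_j - M_full) / S_full.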
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/fast_math.h"
namespace cutlass {
namespace epilogue {
namespace threadblock {
template <
typename ThreadblockShape_,
int ThreadCount,
typename OutputTileIterator_,
typename ElementAccumulator_,
typename ElementNorm_,
typename ElementSum_,
typename ElementSoftmaxCompute_,
typename ElementwiseFunctor_,
bool UseMasking_ = false
>
class EpilogueVisitorSoftmax {
public:
using ThreadblockShape = ThreadblockShape_;
static int const kThreadCount = ThreadCount;
using OutputTileIterator = OutputTileIterator_;
using ElementwiseFunctor = ElementwiseFunctor_;
static int const kIterations = OutputTileIterator::kIterations;
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
using ElementOutput = typename OutputTileIterator::Element;
using LayoutOutput = cutlass::layout::RowMajor;
using ElementAccumulator = ElementAccumulator_;
using ElementNorm = ElementNorm_;
using ElementSum = ElementSum_;
using ElementSoftmaxCompute = ElementSoftmaxCompute_;
using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
using SoftmaxFragment = Array<ElementSoftmaxCompute, kElementsPerAccess>;
using OutputVector = Array<ElementOutput, kElementsPerAccess>;
using TensorRefD = TensorRef<ElementOutput, LayoutOutput>;
static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);
static bool const kUseMasking = UseMasking_;
/// Argument structure
struct Arguments {
typename ElementwiseFunctor::Params elementwise;
int64_t batch_stride_C;
int64_t batch_stride_D;
int64_t batch_stride_Max;
int64_t batch_stride_Sum;
//
// Methods
//
Arguments():
batch_stride_C(0),
batch_stride_D(0),
batch_stride_Max(0),
batch_stride_Sum(0)
{
}
Arguments(
typename ElementwiseFunctor::Params elementwise_
):
elementwise(elementwise_),
batch_stride_C(0),
batch_stride_D(0),
batch_stride_Max(0),
batch_stride_Sum(0)
{
}
Arguments(
typename ElementwiseFunctor::Params elementwise_,
int64_t batch_stride_C_,
int64_t batch_stride_D_,
int64_t batch_stride_Max_,
int64_t batch_stride_Sum_
):
elementwise(elementwise_),
batch_stride_C(batch_stride_C_),
batch_stride_D(batch_stride_D_),
batch_stride_Max(batch_stride_Max_),
batch_stride_Sum(batch_stride_Sum_)
{
}
};
struct Params {
typename ElementwiseFunctor::Params elementwise;
int64_t batch_stride_C;
int64_t batch_stride_D;
int64_t batch_stride_Max;
int64_t batch_stride_Sum;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params()
{
}
CUTLASS_HOST_DEVICE
Params(Arguments const &args):
elementwise(args.elementwise),
batch_stride_C(args.batch_stride_C),
batch_stride_D(args.batch_stride_D),
batch_stride_Max(args.batch_stride_Max),
batch_stride_Sum(args.batch_stride_Sum)
{
}
};
/// Shared storage
struct SharedStorage {
};
private:
Params const & params_;
SharedStorage & shared_storage_;
MatrixCoord extent_;
MatrixCoord extent_real_;
ElementwiseFunctor elementwise_;
OutputTileIterator iterator_C_;
OutputTileIterator iterator_D_;
typename OutputTileIterator::Fragment fragment_C_;
typename OutputTileIterator::Fragment fragment_D_;
ElementAccumulator alpha_;
ElementAccumulator beta_;
ElementNorm *ptr_Max_;
ElementSum *ptr_Sum_;
int column_offset_;
ElementSoftmaxCompute accum_max_;
ElementSoftmaxCompute accum_sum_;
MatrixCoord thread_offset_;
float infinity_;
public:
CUTLASS_DEVICE
EpilogueVisitorSoftmax(
Params const &params,
SharedStorage &shared_storage,
cutlass::MatrixCoord const &problem_size,
int thread_idx,
int warp_idx,
int lane_idx,
typename OutputTileIterator::Params params_C,
typename OutputTileIterator::Params params_D,
typename OutputTileIterator::Element *ptr_C,
typename OutputTileIterator::Element *ptr_D,
ElementNorm *ptr_Max = nullptr,
ElementSum *ptr_Sum = nullptr,
cutlass::MatrixCoord const &threadblock_offset = cutlass::MatrixCoord(0, 0),
int column_offset = 0,
cutlass::MatrixCoord const &problem_size_real = cutlass::MatrixCoord(0, 0),
float infinity = 10000.0f
):
params_(params),
shared_storage_(shared_storage),
extent_(problem_size),
elementwise_(params.elementwise),
iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
ptr_Max_(ptr_Max),
ptr_Sum_(ptr_Sum),
column_offset_(column_offset),
extent_real_(problem_size_real),
infinity_(infinity)
{
alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha);
beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);
if (beta_ == ElementAccumulator()) {
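// Source tensor C is never read when beta is zero, so disable its predicated loads up front.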
iterator_C_.clear_mask();
}
}
/// Helper to indicate split-K behavior
CUTLASS_DEVICE
void set_k_partition(
int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
int split_k_slices) { ///< Total number of split-K slices
}
/// Called to set the batch index
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
}
/// Called at the start of the epilogue just before iterating over accumulator slices
CUTLASS_DEVICE
void begin_epilogue() {
}
/// Called at the start of one step before starting accumulator exchange
CUTLASS_DEVICE
void begin_step(int step_idx) {
fragment_D_.clear();
fragment_C_.clear();
if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
iterator_C_.load(fragment_C_);
++iterator_C_;
}
}
/// Called at the start of a row
CUTLASS_DEVICE
void begin_row(int row_idx) {
// Clear accumulators for max and sum when starting a whole row
clear_accum_();
}
/// Called after accumulators have been exchanged for each accumulator vector
CUTLASS_DEVICE
void visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorFragment const &accum) {
using Mul = cutlass::multiplies<SoftmaxFragment>;
using Minus = cutlass::minus<SoftmaxFragment>;
using Exp = cutlass::fast_exp_op<SoftmaxFragment>;
Minus minus;
Exp exponential;
SoftmaxFragment result;
NumericArrayConverter<ElementSoftmaxCompute, ElementOutput, kElementsPerAccess> source_converter;
OutputVector &source_vector = reinterpret_cast<OutputVector *>(&fragment_C_)[frag_idx];
if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
result = source_converter(elementwise_(accum));
} else {
result = source_converter(elementwise_(accum, source_vector));
}
thread_offset_ =
iterator_D_.thread_start() +
OutputTileIterator::ThreadMap::iteration_offset(frag_idx);
bool column_guard = (thread_offset_.column() < extent_.column());
if (kUseMasking) {
int elements_in_boundary = extent_real_.column() - thread_offset_.column();
elements_in_boundary = (elements_in_boundary > kElementsPerAccess) ? kElementsPerAccess : elements_in_boundary;
elementwise_padding_(result, elements_in_boundary);
}
ElementSoftmaxCompute accum_max_prev = accum_max_;
// Compute the maximum within one row
if (!column_idx) {
// This is the first fragment in a new row
if (column_guard) {
accum_max_ = maximum_accumulator_(result);
}
}
else {
// This is an additional fragment in the same row
if (column_guard) {
accum_max_ = maximum_accumulator_(result, accum_max_);
}
}
// Proactively reduce the running maximum across the warp lanes covering this row
accum_max_ = warp_reduce_max_(accum_max_);
ElementSoftmaxCompute updater = fast_exp(accum_max_prev - accum_max_);
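// Online-softmax rescaling: partial sums accumulated against the previous maximum M_prev are
// re-expressed against the new maximum M_new by multiplying with updater = exp(M_prev - M_new),
// keeping the running sum numerically consistent as the maximum grows.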
SoftmaxFragment intermediate = exponential(minus(result, accum_max_));
if (kHasMultiStepsInRow) {
if (!column_idx) {
accum_sum_ = (column_guard) ? \
sum_accumulator_(intermediate) : ElementSoftmaxCompute(0);
} else {
// Algorithm in $3.1, https://arxiv.org/pdf/2205.14135v1.pdf
// S* = S* x updater + sum_row(P'), where updater = exp(M* - M_row)
accum_sum_ = (column_guard) ? \
sum_accumulator_(intermediate, accum_sum_ * updater) : accum_sum_ * updater;
}
} else {
accum_sum_ = (column_guard) ? sum_accumulator_(intermediate, accum_sum_) : ElementSoftmaxCompute(0);
}
// Convert to the output
NumericArrayConverter<ElementOutput, ElementSoftmaxCompute, kElementsPerAccess> output_converter;
OutputVector &output = reinterpret_cast<OutputVector *>(&fragment_D_)[frag_idx];
output = output_converter(result);
}
/// Called at the end of a row
CUTLASS_DEVICE
void end_row(int row_idx) {
using ConvertSumOutput = cutlass::NumericConverter<ElementSum, ElementSoftmaxCompute>;
using ConvertNormOutput = cutlass::NumericConverter<ElementNorm, ElementSoftmaxCompute>;
ConvertSumOutput convert_sum_output;
ConvertNormOutput convert_norm_output;
// Reduce the accumulated sum across lanes only in this last step of the row
accum_sum_ = warp_reduce_sum_(accum_sum_);
bool is_first_thread_in_tile = ((threadIdx.x % kThreadsPerRow) == 0);
bool row_guard = thread_offset_.row() < extent_.row();
bool is_write_thread = row_guard && is_first_thread_in_tile;
int block_batch = blockIdx.z;
ElementNorm *curr_ptr_max = ptr_Max_ + thread_offset_.row() + column_offset_ + block_batch * params_.batch_stride_Max;
ElementSum *curr_ptr_sum = ptr_Sum_ + thread_offset_.row() + column_offset_ + block_batch * params_.batch_stride_Sum;
arch::global_store<ElementNorm, sizeof(ElementNorm)>(
convert_norm_output(accum_max_),
(void *)curr_ptr_max,
is_write_thread);
arch::global_store<ElementSum, sizeof(ElementSum)>(
convert_sum_output(accum_sum_),
(void *)curr_ptr_sum,
is_write_thread);
// Clear accumulators for max and sum when finishing a whole row
clear_accum_();
}
/// Called after all accumulator elements have been visited
CUTLASS_DEVICE
void end_step(int step_idx) {
iterator_D_.store(fragment_D_);
++iterator_D_;
}
/// Called after all steps have been completed
CUTLASS_DEVICE
void end_epilogue() {
}
private:
CUTLASS_DEVICE
void elementwise_padding_(SoftmaxFragment &result, int elements_in_boundary) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < SoftmaxFragment::kElements; ++i) {
result[i] = (i < elements_in_boundary) ? result[i] : ElementSoftmaxCompute(-infinity_);
}
}
CUTLASS_DEVICE
ElementSoftmaxCompute warp_reduce_sum_(ElementSoftmaxCompute sum_) {
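// Butterfly reduction with __shfl_xor_sync across the kThreadsPerRow lanes that cooperate on one
// row, halving the XOR stride each iteration (assumes kThreadsPerRow is a power of two).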
int half_thread_in_row = (kThreadsPerRow >> 1);
CUTLASS_PRAGMA_UNROLL
for (int i = half_thread_in_row; i > 0; i >>= 1) {
ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, sum_, i);
sum_ += tmp;
}
return sum_;
}
CUTLASS_DEVICE
ElementSoftmaxCompute warp_reduce_max_(ElementSoftmaxCompute max_) {
int half_thread_in_row = (kThreadsPerRow >> 1);
CUTLASS_PRAGMA_UNROLL
for (int i = half_thread_in_row; i > 0; i >>= 1) {
ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, max_, i);
max_ = fast_max(max_, tmp);
}
return max_;
}
CUTLASS_DEVICE
void clear_accum_() {
uint32_t float_max_bits = 0xff7fffff; // -FLT_MAX
float min_float = reinterpret_cast<float const &>(float_max_bits);
accum_max_ = ElementSoftmaxCompute(min_float);
accum_sum_ = ElementSoftmaxCompute(0);
}
CUTLASS_DEVICE
ElementSoftmaxCompute sum_accumulator_(SoftmaxFragment const &accum) {
ElementSoftmaxCompute sum_ = ElementSoftmaxCompute(0);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < SoftmaxFragment::kElements; ++i) {
sum_ += ElementSoftmaxCompute(accum[i]);
}
return sum_;
}
CUTLASS_DEVICE
ElementSoftmaxCompute sum_accumulator_(SoftmaxFragment const &accum, ElementSoftmaxCompute sum_) {
// ElementSoftmaxCompute sum_ = ElementSoftmaxCompute(0);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < SoftmaxFragment::kElements; ++i) {
sum_ += ElementSoftmaxCompute(accum[i]);
}
return sum_;
}
CUTLASS_DEVICE
ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum) {
ElementSoftmaxCompute max_ = accum[0];
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < SoftmaxFragment::kElements; ++i) {
max_ = fast_max(max_, ElementSoftmaxCompute(accum[i]));
}
return max_;
}
CUTLASS_DEVICE
ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum, ElementSoftmaxCompute max_) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < SoftmaxFragment::kElements; ++i) {
max_ = fast_max(max_, ElementSoftmaxCompute(accum[i]));
}
return max_;
}
};
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
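For reference, a minimal host-side sketch (an illustration only, assuming float partials; this is not the CUTLASS reduction kernel) of how the per-threadblock (max, sum) pairs written by this visitor combine into the final row statistics:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

// partials holds one (M_i, S_i) pair per threadblock tile covering a row.
std::pair<float, float> combine_softmax_partials(
    std::vector<std::pair<float, float>> const &partials) {
  float m_full = -INFINITY;
  for (auto const &p : partials) {
    m_full = std::max(m_full, p.first);               // global row maximum
  }
  float s_full = 0.0f;
  for (auto const &p : partials) {
    s_full += p.second * std::exp(p.first - m_full);  // rescale each partial sum to the global max
  }
  return {m_full, s_full};  // softmax(x_j) = exp(x_j - m_full) / s_full
}

int main() {
  auto full = combine_softmax_partials({{2.0f, 1.5f}, {3.0f, 0.8f}});
  std::printf("M_full=%f S_full=%f\n", full.first, full.second);
  return 0;
}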

View File

@ -39,11 +39,12 @@
#pragma once
#include <utility>
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#include <cuda/std/utility>
#else
#include <assert.h>
#include <utility>
#endif
#include "cutlass/cutlass.h"

View File

@ -121,6 +121,7 @@ public:
/// Called after accumulators have been exchanged for each accumulator vector
CUTLASS_DEVICE
void visit(
int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
@ -128,7 +129,7 @@ public:
}
/// Called at the start of a row
/// Called at the end of a row
CUTLASS_DEVICE
void end_row(int row_idx) {
@ -325,6 +326,7 @@ public:
}
visitor.visit(
iter_idx,
row_idx,
col_idx,
idx,

View File

@ -391,10 +391,10 @@ struct OutputTileOptimalThreadMap {
1>;
/// Initial offset function
CUTLASS_HOST_DEVICE
CUTLASS_DEVICE
static MatrixCoord initial_offset(int thread_idx) {
int warp_idx = thread_idx / kWarpSize;
int warp_idx = __shfl_sync(0xffffffff, thread_idx / kWarpSize, 0);
int lane_idx = thread_idx % kWarpSize;
// Compute warp location
@ -419,7 +419,7 @@ struct OutputTileOptimalThreadMap {
return MatrixCoord(
cluster_offset + group_offset + row_offset + lane_row_offset,
(column_offset + lane_col_offset) * kElementsPerAccess
column_offset + lane_col_offset * kElementsPerAccess
);
}
@ -461,10 +461,10 @@ struct OutputTileOptimalThreadMap {
static int const kThreads = Threads;
/// Function to compute each thread's initial offset
CUTLASS_HOST_DEVICE
CUTLASS_DEVICE
static MatrixCoord initial_offset(int thread_idx) {
int warp_idx = thread_idx / kWarpSize;
int warp_idx = __shfl_sync(0xffffffff, thread_idx / kWarpSize, 0);
int lane_idx = thread_idx % kWarpSize;
// Compute warp location
@ -489,7 +489,7 @@ struct OutputTileOptimalThreadMap {
MatrixCoord coord(
cluster_offset + group_offset + row_offset + lane_row_offset,
(column_offset + lane_col_offset) * kElementsPerAccess
column_offset + lane_col_offset * kElementsPerAccess
);
return coord;

View File

@ -43,6 +43,7 @@
#include "cutlass/array.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/permute.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
@ -70,6 +71,7 @@ template <
typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap)
typename Element_, ///< Element data type
bool ScatterD = false, ///< Scatter D operand or not
typename PermuteDLayout = layout::NoPermute, ///< Permute D operand or not
bool UseCUDAStore = false
>
class PredicatedTileIterator {
@ -173,9 +175,12 @@ private:
/// Parameters structure containing reference and precomputed state.
PredicatedTileIteratorParams params_;
/// Byte-level pointer
/// Byte-level pointer. This pointer is used for both load() and store() unless PermuteD is applied; with PermuteD, byte_pointer_ is used only for load().
uint8_t *byte_pointer_;
/// Byte-level pointer for store(). When PermuteD is applied, store_byte_pointer_ may use a different address computation than byte_pointer_.
uint8_t *store_byte_pointer_;
/// Array of boolean values to contain steady-state predicates
Mask mask_;
@ -196,6 +201,11 @@ private:
/// Scatter indices
int const *indices_;
/// Whether to perform Permute Op
bool PermuteD;
/// PermuteDLayout
mutable PermuteDLayout permute_layout_;
//
// Static asserts about internal strides
@ -255,7 +265,7 @@ public:
mask_.clear();
}
// Initialize pointer
// Initialize byte_pointer_
byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
LongIndex(thread_offset.row()) * LongIndex(params_.stride) +
LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess;
@ -265,6 +275,19 @@ public:
LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess;
}
// store_byte_pointer_ is the same as byte_pointer_ unless PermuteD is used.
store_byte_pointer_ = byte_pointer_;
// Initialize PermuteD. If PermuteD is true, store_byte_pointer_ is initialized accordingly.
if (platform::is_same<PermuteDLayout, layout::NoPermute>::value) {
PermuteD = false;
} else {
PermuteD = true;
store_byte_pointer_ = reinterpret_cast<uint8_t *>(pointer);
permute_layout_ = PermuteDLayout(extent,
params_.stride * kElementsPerAccess / sizeof(AccessType));
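// Note: params_.stride is a byte stride; multiplying by kElementsPerAccess and dividing by
// sizeof(AccessType) (== kElementsPerAccess * sizeof(Element)) yields the stride in elements,
// which is the unit the permute layout is constructed with here.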
}
// Initialize internal state counter
state_[0] = state_[1] = state_[2] = 0;
}
@ -272,6 +295,7 @@ public:
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
store_byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
}
@ -353,7 +377,7 @@ public:
/// Stores a fragment to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) const {
uint8_t *byte_pointer = byte_pointer_;
uint8_t *byte_pointer = store_byte_pointer_;
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
CUTLASS_PRAGMA_UNROLL
@ -388,21 +412,38 @@ public:
bool guard = row_guard && mask_.predicates[column];
int col_offset = column * ThreadMap::Delta::kColumn;
if (PermuteD) {
int col = col_offset + thread_start_column_;
int row = row_offset + thread_start_row_;
TensorCoord init_coord(row, col);
// Locate memory_pointer
memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset
+ permute_layout_(init_coord) * sizeof(AccessType) / kElementsPerAccess);
}
if (UseCUDAStore) {
if (guard) {
memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] =
memory_pointer[0] =
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column];
}
} else {
cutlass::arch::global_store<AccessType, sizeof(AccessType)>(
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column],
(void *)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess],
(void *)&memory_pointer[0],
guard);
}
if (!PermuteD) {
memory_pointer += (ThreadMap::Delta::kColumn / kElementsPerAccess);
}
}
if (row + 1 < ThreadMap::Iterations::kRow) {
if (!ScatterD) {
if (!ScatterD && !PermuteD) {
byte_pointer += params_.increment_row;
}
}
@ -605,6 +646,10 @@ public:
++state_[0];
if (!ScatterD && !PermuteD) {
store_byte_pointer_ += params_.advance_row;
}
if (!ScatterD) {
byte_pointer_ += params_.advance_row;
}
@ -616,6 +661,7 @@ public:
state_[0] = 0;
++state_[1];
byte_pointer_ += params_.advance_group;
store_byte_pointer_ += params_.advance_group;
thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
@ -625,6 +671,7 @@ public:
state_[1] = 0;
++state_[2];
byte_pointer_ += params_.advance_cluster;
store_byte_pointer_ += params_.advance_cluster;
thread_start_row_ += ThreadMap::Count::kGroup *
ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
@ -632,6 +679,7 @@ public:
if (state_[2] == ThreadMap::Count::kCluster) {
state_[2] = 0;
byte_pointer_ += params_.advance_tile;
store_byte_pointer_ += params_.advance_tile;
}
}
}
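To make the new PermuteD hook concrete, below is a minimal sketch of the interface a permute layout must satisfy, inferred from how permute_layout_ is constructed and invoked above; it is an illustration only, not one of the layouts shipped in cutlass/layout/permute.h. The functor is built from the output extent and an element-unit leading stride, and maps a (row, column) coordinate to a linear element offset in the permuted tensor.

#include <cstdint>
#include "cutlass/cutlass.h"
#include "cutlass/matrix_coord.h"

// Hypothetical identity permutation, used only to illustrate the required shape of the functor.
struct IdentityPermuteSketch {
  cutlass::MatrixCoord extent;
  int64_t stride;   // leading dimension of D, in elements

  CUTLASS_HOST_DEVICE
  IdentityPermuteSketch(): extent(), stride(0) { }

  CUTLASS_HOST_DEVICE
  IdentityPermuteSketch(cutlass::MatrixCoord extent_, int64_t stride_):
    extent(extent_), stride(stride_) { }

  // Returns the element offset at which (row, column) lands after permutation.
  // The identity mapping shown here reproduces plain row-major addressing.
  CUTLASS_HOST_DEVICE
  int64_t operator()(cutlass::MatrixCoord const &coord) const {
    return int64_t(coord.row()) * stride + coord.column();
  }
};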

View File

@ -35,6 +35,8 @@
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/matrix.h"
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -908,3 +908,4 @@ T absolute_value(T x) {
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,478 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Base device-level grouped kernel.
*/
#pragma once
#include <algorithm>
#include <cstring>
#include <limits>
#include <numeric>
#include <vector>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/trace.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM Grouped
template <typename BaseKernel_>
class BaseGrouped {
public:
using BaseKernel = BaseKernel_;
using ElementA = typename BaseKernel::ElementA;
using LayoutA = typename BaseKernel::LayoutA;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
static ComplexTransform const kTransformA = BaseKernel::kTransformA;
static int const kAlignmentA = BaseKernel::kAlignmentA;
using ElementB = typename BaseKernel::ElementB;
using LayoutB = typename BaseKernel::LayoutB;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
static ComplexTransform const kTransformB = BaseKernel::kTransformB;
static int const kAlignmentB = BaseKernel::kAlignmentB;
using ElementC = typename BaseKernel::ElementC;
using LayoutC = typename BaseKernel::LayoutC;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
static int const kAlignmentC = BaseKernel::kAlignmentC;
using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC;
using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp;
using ThreadblockSwizzle = typename BaseKernel::ThreadblockSwizzle;
using Operator = typename BaseKernel::Operator;
using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator;
using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
using MathOperator = typename WarpMmaOperator::MathOperator;
using OperatorClass = typename WarpMmaOperator::OperatorClass;
using ArchTag = typename WarpMmaOperator::ArchTag;
using ThreadblockShape = typename BaseKernel::Mma::Shape;
using WarpShape = typename BaseKernel::WarpShape;
using InstructionShape = typename BaseKernel::InstructionShape;
static int const kStages = BaseKernel::Mma::kStages;
/// Argument structure
using Arguments = typename BaseKernel::Arguments;
using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo;
protected:
/// Kernel parameters object
typename BaseKernel::Params params_;
private:
/// Get the number of tiles across all problems in a group
static int32_t group_tile_count(const cutlass::gemm::GemmCoord* problem_sizes_ptr, int problem_count) {
int32_t tiles = 0;
for (int32_t i = 0; i < problem_count; ++i) {
cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i];
BaseKernel::ProblemVisitor::possibly_transpose_problem(problem);
tiles += problem_tile_count(problem);
}
return tiles;
}
/// Copy from `data` to `workspace`
Status copy_to_workspace(void* workspace, void* data, size_t bytes) {
cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice);
if (cuda_error != cudaSuccess) {
// Call cudaGetLastError() to clear the error bit
cuda_error = cudaGetLastError();
CUTLASS_TRACE_HOST(
" cudaMemcpy() returned error "
<< cudaGetErrorString(cuda_error));
return Status::kErrorInternal;
}
return Status::kSuccess;
}
/// Precomputes scheduling information for the grouped GEMM
Status precompute(Arguments const &args, int32_t tile_count, void* workspace) {
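// Scheduling metadata is computed on the host into a temporary buffer and then copied into the
// device workspace consumed by the kernel's problem visitor.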
size_t workspace_bytes = get_workspace_size(args);
std::vector<uint8_t> host_workspace(workspace_bytes);
BaseKernel::ProblemVisitor::host_precompute(args.host_problem_sizes,
args.problem_count,
args.threadblock_count,
(void*)host_workspace.data());
return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes);
}
/// Reorder `data` according to `indices`
template <typename T>
static void reorder_array(T* data, const std::vector<size_t>& indices) {
// For now, simply create a copy of the data and then copy over to the original.
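// The memcpy back into `data` assumes T is trivially copyable, which holds for the GemmCoord and
// int64_t arrays passed in from sort_problems().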
std::vector<T> copy(indices.size());
for (int i = 0; i < indices.size(); ++i) {
copy.at(i) = data[indices[i]];
}
memcpy(data, copy.data(), indices.size() * sizeof(T));
}
public:
/// Constructs the GEMM.
BaseGrouped() { }
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
return BaseKernel::can_implement(args);
}
/// Get the number of tiles in a problem
static int32_t problem_tile_count(cutlass::gemm::GemmCoord const &problem) {
auto grid = BaseKernel::ProblemVisitor::grid_shape(problem);
return BaseKernel::ProblemVisitor::tile_count(grid);
}
/// Get the number of tiles across all problems in a group
static int32_t group_tile_count(Arguments const &args) {
if (args.host_problem_sizes == nullptr) {
CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes");
return -1;
}
return group_tile_count(args.host_problem_sizes, args.problem_count);
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) {
return BaseKernel::ProblemVisitor::get_workspace_size(args.host_problem_sizes,
args.problem_count,
args.threadblock_count);
} else {
return 0;
}
}
/// Computes the grid shape
static dim3 get_grid_shape(Arguments const &args) {
return dim3(args.threadblock_count, 1, 1);
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int smem_capacity = -1) {
CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()");
int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
cudaError_t result;
if (smem_size > (48 << 10)) {
result = cudaFuncSetAttribute(Kernel<BaseKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (result != cudaSuccess) {
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(
" cudaFuncSetAttribute() returned error "
<< cudaGetErrorString(result));
return -1;
}
}
int max_active_blocks = -1;
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks,
Kernel<BaseKernel>,
BaseKernel::kThreadCount,
smem_size);
if (result != cudaSuccess) {
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
<< cudaGetErrorString(result));
return -1;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
/// Sorts each array passed in according to the indices that sort
/// `problem_sizes_ptr` in descending order of the problem K dimension.
static void sort_problems(int problem_count,
cutlass::gemm::GemmCoord* problem_sizes_ptr,
int64_t* lda_host_ptr,
int64_t* ldb_host_ptr,
int64_t* ldc_host_ptr,
int64_t* ldd_host_ptr,
int64_t* offset_A_ptr,
int64_t* offset_B_ptr,
int64_t* offset_C_ptr,
int64_t* offset_D_ptr)
{
std::vector<size_t> indices(problem_count);
std::iota(indices.begin(), indices.end(), 0);
std::stable_sort(indices.begin(), indices.end(),
[&problem_sizes_ptr](size_t i, size_t j) {
return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k();
});
reorder_array(problem_sizes_ptr, indices);
reorder_array(lda_host_ptr, indices);
reorder_array(ldb_host_ptr, indices);
reorder_array(ldc_host_ptr, indices);
reorder_array(ldd_host_ptr, indices);
reorder_array(offset_A_ptr, indices);
reorder_array(offset_B_ptr, indices);
reorder_array(offset_C_ptr, indices);
reorder_array(offset_D_ptr, indices);
}
/// Computes the number of threadblocks to launch for the grouped kernel
static int sufficient(const cutlass::gemm::GemmCoord* problem_sizes_ptr=nullptr,
int problem_count=0,
int available_sm_count=-1) {
// Determine the number of blocks that would be launched to fill up a single
// wave on the GPU with each SM having maximum occupancy.
cudaDeviceProp properties;
int device_idx;
cudaError_t result = cudaGetDevice(&device_idx);
if (result != cudaSuccess) {
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(" cudaGetDevice() returned error "
<< cudaGetErrorString(result));
return 0;
}
result = cudaGetDeviceProperties(&properties, device_idx);
if (result != cudaSuccess) {
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(" cudaGetDeviceProperties() returned error "
<< cudaGetErrorString(result));
return 0;
}
bool override_sm_count = (available_sm_count < 0 || available_sm_count > properties.multiProcessorCount);
if (override_sm_count) {
available_sm_count = properties.multiProcessorCount;
}
int max_active_blocks = maximum_active_blocks();
if (max_active_blocks <= 0) {
return 0;
}
int occupancy_based_block_count = available_sm_count * max_active_blocks;
if (problem_sizes_ptr == nullptr || problem_count == 0) {
return occupancy_based_block_count;
}
int total_tiles = group_tile_count(problem_sizes_ptr, problem_count);
// If the group contains a single problem, launching the exact number of
// threadblocks needed to cover the problem minimizes the work performed
// per threadblock in finding the next tile to compute. We return total_tiles
// unless the user has provided the SM count.
if (problem_count == 1 && override_sm_count) {
return total_tiles;
}
// Choose between the full wave of threadblocks and the tile count. If there
// are fewer tiles in the group than threadblocks in the full wave, only
// some threadblocks will be assigned tiles. Those threadblocks
// which are not assigned tiles still need to perform the work of iterating through
// problem sizes to determine that they have no work to do. This competes for cycles
// with those threadblocks that are assigned tiles to compute.
return min(total_tiles, occupancy_based_block_count);
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
// Workspace
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes && !workspace) {
return Status::kErrorWorkspaceNull;
}
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) {
int32_t tile_count = group_tile_count(args);
Status status = precompute(args, tile_count, workspace);
if (status != Status::kSuccess) {
return status;
}
params_ = typename BaseKernel::Params(args, workspace, tile_count);
} else {
params_ = typename BaseKernel::Params(args, workspace);
}
// Specify shared memory capacity for kernel.
int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
if (smem_size >= (48 << 10)) {
cudaError_t result = cudaFuncSetAttribute(Kernel<BaseKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes && !workspace) {
return Status::kErrorWorkspaceNull;
}
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) {
int32_t tile_count = group_tile_count(args);
Status status = precompute(args, tile_count, workspace);
if (status != Status::kSuccess) {
return status;
}
params_.update(args, workspace, tile_count);
} else {
params_.update(args, workspace);
}
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
//
// Configure grid and block dimensions
//
if (!params_.problem_visitor.problem_count) {
return Status::kSuccess;
}
dim3 grid(params_.threadblock_count, 1, 1);
dim3 block(BaseKernel::kThreadCount, 1, 1);
int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
//
// Launch kernel
//
// Launch
cutlass::Kernel<BaseKernel><<<grid, block, smem_size, stream>>>(params_);
//
// Query for errors
//
cudaError_t result = cudaGetLastError();
if (result != cudaSuccess) {
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
return run(stream);
}
/// Initializes and runs the kernel.
Status operator()(
Arguments const &args,
void *workspace,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
}
return status;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
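As a closing illustration of the launch-sizing policy in sufficient(): launch one full wave of threadblocks (SM count times per-SM occupancy) unless the group has fewer tiles than that. A self-contained numeric sketch of that decision, with made-up device numbers (not queried from a real GPU):

#include <algorithm>
#include <cstdio>

int main() {
  int sm_count          = 108;   // hypothetical device, e.g. an A100-class GPU
  int max_active_blocks = 2;     // assumed occupancy of the grouped kernel per SM
  int total_tiles       = 150;   // assumed tile count across all problems in the group

  // One full wave of threadblocks at maximum occupancy.
  int occupancy_based_block_count = sm_count * max_active_blocks;

  // Launch no more threadblocks than there are tiles to compute.
  int threadblock_count = std::min(total_tiles, occupancy_based_block_count);

  std::printf("launching %d threadblocks for %d tiles\n", threadblock_count, total_tiles);
  return 0;
}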
