/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Reference implementation for convolution in device-side code.
*/
#pragma once
#include "cutlass/coord.h"
#include "cutlass/functional.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv2d_problem_size.h"
#include "cutlass/conv/conv3d_problem_size.h"
namespace cutlass {
namespace reference {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace kernel {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Conv2d device reference kernel
////////////////////////////////////////////////////////////////////////////////////////////////////
// Conv2d Fprop kernel - y = fprop(x, w)
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementCompute,
typename ElementAccumulator = ElementCompute,
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
typename InnerProductOp = multiply_add<ElementAccumulator>,
int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension
int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension
int kCtaShapeM = 16, // shape of a threadblock in units of threads
int kCtaShapeN = 8 // shape of a threadblock in units of threads
>
__global__ void Conv2dFprop(
conv::Conv2dProblemSize problem_size,
TensorRef<ElementA, LayoutA> tensor_x,
TensorRef<ElementB, LayoutB> tensor_w,
TensorRef<ElementC, LayoutC> tensor_y_in,
TensorRef<ElementC, LayoutC> tensor_y_out,
ElementCompute alpha,
ElementCompute beta
) {
ConvertOp convert_op;
InnerProductOp inner_product_op;
ElementAccumulator element_A[kThreadM];
ElementAccumulator element_B[kThreadN];
ElementAccumulator accum[kThreadM][kThreadN];
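// Implicit GEMM view: each thread computes a kThreadM x kThreadN tile of the output
// y(N, P, Q, K) treated as an (N * P * Q) x K matrix, with blockIdx.x / threadIdx.x covering
// the flattened NPQ (GEMM M) extent and blockIdx.y / threadIdx.y covering K (GEMM N).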
int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;
int thread_n[kThreadM];
int thread_p[kThreadM];
int thread_q[kThreadM];
// Compute N, P, Q coordinates for each row of a thread's tile
int64_t PQ = int64_t(problem_size.P) * problem_size.Q;
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
int64_t npq = npq_start + m;
thread_n[m] = int(npq / PQ);
int64_t residual = npq % PQ;
thread_p[m] = int(residual / problem_size.Q);
thread_q[m] = int(residual % problem_size.Q);
}
// Clear accumulators
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
accum[m][n] = ElementAccumulator();
}
}
int c_per_group = problem_size.C / problem_size.groups;
int k_per_group = problem_size.K / problem_size.groups;
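// For grouped convolution, the C input channels and K filters are partitioned into
// problem_size.groups contiguous blocks; an input channel contributes only to filters
// belonging to the same group.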
// Compute convolution
for (int R = 0; R < problem_size.R; ++R) {
for (int S = 0; S < problem_size.S; ++S) {
for (int C = 0; C < problem_size.C; ++C) {
// Get the group index of the current channel
int c_group_idx = C / c_per_group;
// Load from activations tensor
int filter_r = R;
int filter_s = S;
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
filter_r = problem_size.R - 1 - R;
filter_s = problem_size.S - 1 - S;
}
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;
if (thread_n[m] < problem_size.N && h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) {
element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], h, w, C}));
}
else {
element_A[m] = ElementAccumulator();
}
}
// Load from filters tensor
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
int thread_k = k_start + n;
int k_group_idx = thread_k / k_per_group;
if (thread_k < problem_size.K && k_group_idx == c_group_idx) {
element_B[n] = ElementAccumulator(tensor_w.at({thread_k, R, S, C % c_per_group}));
}
else {
element_B[n] = ElementAccumulator();
}
}
// Accumulate matrix product
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
}
}
}
}
}
// Write out the results
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
if (thread_n[m] < problem_size.N && thread_p[m] < problem_size.P && thread_q[m] < problem_size.Q) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
int thread_k = k_start + n;
if (thread_k < problem_size.K) {
ElementCompute c_ref = ElementCompute();
if (beta != ElementCompute()) {
c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k}));
}
tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
alpha * ElementCompute(accum[m][n]) + beta * c_ref);
}
}
}
}
}
// Conv3d Fprop kernel - y = fprop(x, w)
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementCompute,
typename ElementAccumulator = ElementCompute,
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
typename InnerProductOp = multiply_add<ElementAccumulator>,
int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension
int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension
int kCtaShapeM = 16, // shape of a threadblock in units of threads
int kCtaShapeN = 8 // shape of a threadblock in units of threads
>
__global__ void Conv3dFprop(
conv::Conv3dProblemSize problem_size,
TensorRef<ElementA, LayoutA> tensor_x,
TensorRef<ElementB, LayoutB> tensor_w,
TensorRef<ElementC, LayoutC> tensor_y_in,
TensorRef<ElementC, LayoutC> tensor_y_out,
ElementCompute alpha,
ElementCompute beta
) {
ConvertOp convert_op;
InnerProductOp inner_product_op;
ElementAccumulator element_A[kThreadM];
ElementAccumulator element_B[kThreadN];
ElementAccumulator accum[kThreadM][kThreadN];
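// Implicit GEMM view: as in Conv2dFprop, except GEMM M is the flattened N * Z * P * Q output extent.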
int64_t nzpq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;
int thread_n[kThreadM];
int thread_z[kThreadM];
int thread_p[kThreadM];
int thread_q[kThreadM];
// Compute N, Z, P, Q coordinates for each row of a thread's tile
int64_t PQ = int64_t(problem_size.P) * problem_size.Q;
int64_t ZPQ = PQ * problem_size.Z;
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
int64_t nzpq = nzpq_start + m;
thread_n[m] = int(nzpq / ZPQ);
int64_t residual = nzpq % ZPQ;
thread_z[m] = int(residual / PQ);
residual = residual % PQ;
thread_p[m] = int(residual / problem_size.Q);
thread_q[m] = int(residual % problem_size.Q);
}
// Clear accumulators
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
accum[m][n] = ElementAccumulator();
}
}
// Compute convolution
for (int T = 0; T < problem_size.T; ++T) {
for (int R = 0; R < problem_size.R; ++R) {
for (int S = 0; S < problem_size.S; ++S) {
for (int C = 0; C < problem_size.C; ++C) {
// Load from activations tensor
int filter_t = T;
int filter_r = R;
int filter_s = S;
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
filter_t = problem_size.T - 1 - T;
filter_r = problem_size.R - 1 - R;
filter_s = problem_size.S - 1 - S;
}
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
int d = thread_z[m] * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d;
int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;
if (thread_n[m] < problem_size.N &&
d >= 0 && d < problem_size.D &&
h >= 0 && h < problem_size.H &&
w >= 0 && w < problem_size.W) {
element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], d, h, w, C}));
}
else {
element_A[m] = ElementAccumulator();
}
}
// Load from filters tensor
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
int thread_k = k_start + n;
if (thread_k < problem_size.K) {
element_B[n] = ElementAccumulator(tensor_w.at({thread_k, T, R, S, C}));
}
else {
element_B[n] = ElementAccumulator();
}
}
// Accumulate matrix product
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
}
}
} // for (C)
} // for (S)
} // for (R)
} // for (T)
// Write out the results
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
if (thread_n[m] < problem_size.N &&
thread_z[m] < problem_size.Z &&
thread_p[m] < problem_size.P &&
thread_q[m] < problem_size.Q) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
int thread_k = k_start + n;
if (thread_k < problem_size.K) {
ElementCompute c_ref = ElementCompute();
if (beta != ElementCompute()) {
c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k}));
}
tensor_y_out.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
alpha * ElementCompute(accum[m][n]) + beta * c_ref);
}
} // for (n)
}
} // for (m)
}
///////////////////////////////////////////////////////////////////////////////////////////////////
// Conv2d Dgrad kernel - dx = dgrad(dy, w)
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementCompute,
typename ElementAccumulator = ElementCompute,
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
typename InnerProductOp = multiply_add<ElementAccumulator>,
int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension
int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension
int kCtaShapeM = 16, // shape of a threadblock in units of threads
int kCtaShapeN = 8 // shape of a threadblock in units of threads
>
__global__ void Conv2dDgrad(
conv::Conv2dProblemSize problem_size,
TensorRef<ElementA, LayoutA> tensor_dy,
TensorRef<ElementB, LayoutB> tensor_w,
TensorRef<ElementC, LayoutC> tensor_dx_in,
TensorRef<ElementC, LayoutC> tensor_dx_out,
ElementCompute alpha,
ElementCompute beta
) {
ConvertOp convert_op;
InnerProductOp inner_product_op;
ElementAccumulator element_A[kThreadM];
ElementAccumulator element_B[kThreadN];
ElementAccumulator accum[kThreadM][kThreadN];
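// Implicit GEMM view for dgrad: GEMM M is the flattened N * H * W input-gradient extent and GEMM N is C.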
int64_t nhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;
int thread_n[kThreadM];
int thread_h[kThreadM];
int thread_w[kThreadM];
// Compute N, H, W coordinates for each row of a thread's tile
int64_t HW = int64_t(problem_size.H) * problem_size.W;
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
int64_t nhw = nhw_start + m;
thread_n[m] = int(nhw / HW);
int64_t residual = nhw % HW;
thread_h[m] = int(residual / problem_size.W);
thread_w[m] = int(residual % problem_size.W);
}
// Clear accumulators
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
accum[m][n] = ElementAccumulator();
}
}
// Compute convolution
for (int R = 0; R < problem_size.R; ++R) {
for (int S = 0; S < problem_size.S; ++S) {
for (int K = 0; K < problem_size.K; ++K) {
// Load from the output gradient tensor
int filter_r = R;
int filter_s = S;
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
filter_r = problem_size.R - 1 - R;
filter_s = problem_size.S - 1 - S;
}
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h;
int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w;
element_A[m] = ElementAccumulator();
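// Invert the fprop index map: dy(n, p, q, k) contributes to dx(n, h, w, c) only when
// h + pad_h - r * dilation_h is non-negative and divisible by stride_h (and likewise for w).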
if (p >= 0 && !(p % problem_size.stride_h) && q >= 0 && !(q % problem_size.stride_w)) {
p = p / problem_size.stride_h;
q = q / problem_size.stride_w;
if (thread_n[m] < problem_size.N && p < problem_size.P && q < problem_size.Q) {
element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], p, q, K}));
}
}
}
// Load from filters tensor
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
int thread_c = c_start + n;
if (thread_c < problem_size.C) {
element_B[n] = ElementAccumulator(tensor_w.at({K, R, S, thread_c}));
}
else {
element_B[n] = ElementAccumulator();
}
}
// Accumulate matrix product
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
}
}
}
}
}
// Write out the results
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
if (thread_n[m] < problem_size.N && thread_h[m] < problem_size.H && thread_w[m] < problem_size.W) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
int thread_c = c_start + n;
if (thread_c < problem_size.C) {
ElementCompute c_ref = ElementCompute();
if (beta != ElementCompute()) {
c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_h[m], thread_w[m], thread_c}));
}
tensor_dx_out.at({thread_n[m], thread_h[m], thread_w[m], thread_c}) = convert_op(
alpha * ElementCompute(accum[m][n]) + beta * c_ref);
}
}
}
}
}
// Conv3d Dgrad kernel - dx = dgrad(dy, w)
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementCompute,
typename ElementAccumulator = ElementCompute,
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
typename InnerProductOp = multiply_add<ElementAccumulator>,
int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension
int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension
int kCtaShapeM = 16, // shape of a threadblock in units of threads
int kCtaShapeN = 8 // shape of a threadblock in units of threads
>
__global__ void Conv3dDgrad(
conv::Conv3dProblemSize problem_size,
TensorRef<ElementA, LayoutA> tensor_dy,
TensorRef<ElementB, LayoutB> tensor_w,
TensorRef<ElementC, LayoutC> tensor_dx_in,
TensorRef<ElementC, LayoutC> tensor_dx_out,
ElementCompute alpha,
ElementCompute beta
) {
ConvertOp convert_op;
InnerProductOp inner_product_op;
ElementAccumulator element_A[kThreadM];
ElementAccumulator element_B[kThreadN];
ElementAccumulator accum[kThreadM][kThreadN];
int64_t ndhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;
int thread_n[kThreadM];
int thread_d[kThreadM];
int thread_h[kThreadM];
int thread_w[kThreadM];
// Compute N, D, H, W coordinates for each row of a thread's tile
int64_t HW = int64_t(problem_size.H) * problem_size.W;
int64_t DHW = HW * problem_size.D;
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
int64_t ndhw = ndhw_start + m;
thread_n[m] = int(ndhw / DHW);
int64_t residual = ndhw % DHW;
thread_d[m] = int(residual / HW);
residual = residual % HW;
thread_h[m] = int(residual / problem_size.W);
thread_w[m] = int(residual % problem_size.W);
}
// Clear accumulators
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
accum[m][n] = ElementAccumulator();
}
}
// Compute convolution
for (int T = 0; T < problem_size.T; ++T) {
for (int R = 0; R < problem_size.R; ++R) {
for (int S = 0; S < problem_size.S; ++S) {
for (int K = 0; K < problem_size.K; ++K) {
// Load from the output gradient tensor
int filter_t = T;
int filter_r = R;
int filter_s = S;
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
filter_t = problem_size.T - 1 - T;
filter_r = problem_size.R - 1 - R;
filter_s = problem_size.S - 1 - S;
}
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
int z = thread_d[m] + problem_size.pad_d - filter_t * problem_size.dilation_d;
int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h;
int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w;
element_A[m] = ElementAccumulator();
if (z >= 0 && !(z % problem_size.stride_d) &&
p >= 0 && !(p % problem_size.stride_h) &&
q >= 0 && !(q % problem_size.stride_w)) {
z = z / problem_size.stride_d;
p = p / problem_size.stride_h;
q = q / problem_size.stride_w;
if (thread_n[m] < problem_size.N && z < problem_size.Z && p < problem_size.P && q < problem_size.Q) {
element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], z, p, q, K}));
}
}
}
// Load from filters tensor
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
int thread_c = c_start + n;
if (thread_c < problem_size.C) {
element_B[n] = ElementAccumulator(tensor_w.at({K, T, R, S, thread_c}));
}
else {
element_B[n] = ElementAccumulator();
}
}
// Accumulate matrix product
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
}
}
} // for (K)
} // for (S)
} // for (R)
} // for (T)
// Write out the results
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
if (thread_n[m] < problem_size.N &&
thread_d[m] < problem_size.D &&
thread_h[m] < problem_size.H &&
thread_w[m] < problem_size.W) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
int thread_c = c_start + n;
if (thread_c < problem_size.C) {
ElementCompute c_ref = ElementCompute();
if (beta != ElementCompute()) {
c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c}));
}
tensor_dx_out.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c}) = convert_op(
alpha * ElementCompute(accum[m][n]) + beta * c_ref);
}
}
}
}
}
///////////////////////////////////////////////////////////////////////////////////////////////////
// Conv2d Wgrad kernel - dw = wgrad(dy, x)
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementCompute,
typename ElementAccumulator = ElementCompute,
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
typename InnerProductOp = multiply_add<ElementAccumulator>,
int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension
int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension
int kCtaShapeM = 8, // shape of a threadblock in units of threads
int kCtaShapeN = 16 // shape of a threadblock in units of threads
>
__global__ void Conv2dWgrad(
conv::Conv2dProblemSize problem_size,
TensorRef<ElementA, LayoutA> tensor_dy,
TensorRef<ElementB, LayoutB> tensor_x,
TensorRef<ElementC, LayoutC> tensor_dw_in,
TensorRef<ElementC, LayoutC> tensor_dw_out,
ElementCompute alpha,
ElementCompute beta
) {
ConvertOp convert_op;
InnerProductOp inner_product_op;
ElementAccumulator element_A[kThreadM];
ElementAccumulator element_B[kThreadN];
ElementAccumulator accum[kThreadM][kThreadN];
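// Implicit GEMM view for wgrad: GEMM M is K and GEMM N is the flattened R * S * C filter extent.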
int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
int64_t rsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;
int thread_r[kThreadN];
int thread_s[kThreadN];
int thread_c[kThreadN];
// Compute R, S, C coordinates for each column of a thread's tile
int64_t SC = int64_t(problem_size.S) * problem_size.C;
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
int64_t rsc = rsc_start + n;
int64_t residual = rsc % SC;
thread_r[n] = int(rsc / SC);
thread_s[n] = int(residual / problem_size.C);
thread_c[n] = int(residual % problem_size.C);
}
// Clear accumulators
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
accum[m][n] = ElementAccumulator();
}
}
// Compute convolution
for (int N = 0; N < problem_size.N; ++N) {
for (int P = 0; P < problem_size.P; ++P) {
for (int Q = 0; Q < problem_size.Q; ++Q) {
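// Load from the output gradient tensor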
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
int thread_k = k_start + m;
element_A[m] = ElementAccumulator();
if (thread_k < problem_size.K) {
element_A[m] = ElementAccumulator(tensor_dy.at({N, P, Q, thread_k}));
}
}
// Load from the activations tensor
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
int filter_r = thread_r[n];
int filter_s = thread_s[n];
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
filter_r = problem_size.R - 1 - filter_r;
filter_s = problem_size.S - 1 - filter_s;
}
int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;
element_B[n] = ElementAccumulator();
if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W && thread_c[n] < problem_size.C) {
element_B[n] = ElementAccumulator(tensor_x.at({N, h, w, thread_c[n]}));
}
}
// Accumulate matrix product
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
}
}
}
}
}
// Write out the results
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
int thread_k = k_start + m;
if (thread_k < problem_size.K) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
if (thread_r[n] < problem_size.R && thread_s[n] < problem_size.S && thread_c[n] < problem_size.C) {
ElementCompute c_ref = ElementCompute();
if (beta != ElementCompute()) {
c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_r[n], thread_s[n], thread_c[n]}));
}
tensor_dw_out.at({thread_k, thread_r[n], thread_s[n], thread_c[n]}) = convert_op(
alpha * ElementCompute(accum[m][n]) + beta * c_ref);
}
}
}
}
}
// Conv3d Wgrad kernel - dw = wgrad(dy, x)
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementCompute,
typename ElementAccumulator = ElementCompute,
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
typename InnerProductOp = multiply_add<ElementAccumulator>,
int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension
int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension
int kCtaShapeM = 8, // shape of a threadblock in units of threads
int kCtaShapeN = 16 // shape of a threadblock in units of threads
>
__global__ void Conv3dWgrad(
conv::Conv3dProblemSize problem_size,
TensorRef<ElementA, LayoutA> tensor_dy,
TensorRef<ElementB, LayoutB> tensor_x,
TensorRef<ElementC, LayoutC> tensor_dw_in,
TensorRef<ElementC, LayoutC> tensor_dw_out,
ElementCompute alpha,
ElementCompute beta
) {
ConvertOp convert_op;
InnerProductOp inner_product_op;
ElementAccumulator element_A[kThreadM];
ElementAccumulator element_B[kThreadN];
ElementAccumulator accum[kThreadM][kThreadN];
int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
int64_t trsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;
int thread_t[kThreadN];
int thread_r[kThreadN];
int thread_s[kThreadN];
int thread_c[kThreadN];
// Compute T, R, S, C coordinates for each column of a thread's tile
int64_t SC = int64_t(problem_size.S) * problem_size.C;
int64_t RSC = SC * problem_size.R;
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
int64_t trsc = trsc_start + n;
thread_t[n] = int(trsc / RSC);
int64_t residual = trsc % RSC;
thread_r[n] = int(residual / SC);
residual = residual % SC;
thread_s[n] = int(residual / problem_size.C);
thread_c[n] = int(residual % problem_size.C);
}
// Clear accumulators
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
accum[m][n] = ElementAccumulator();
}
}
// Compute convolution
for (int N = 0; N < problem_size.N; ++N) {
for (int Z = 0; Z < problem_size.Z; ++Z) {
for (int P = 0; P < problem_size.P; ++P) {
for (int Q = 0; Q < problem_size.Q; ++Q) {
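// Load from the output gradient tensor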
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
int thread_k = k_start + m;
element_A[m] = ElementAccumulator();
if (thread_k < problem_size.K) {
element_A[m] = ElementAccumulator(tensor_dy.at({N, Z, P, Q, thread_k}));
}
}
// Load from the activations tensor
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
int filter_t = thread_t[n];
int filter_r = thread_r[n];
int filter_s = thread_s[n];
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
filter_t = problem_size.T - 1 - filter_t;
filter_r = problem_size.R - 1 - filter_r;
filter_s = problem_size.S - 1 - filter_s;
}
int d = Z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d;
int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;
element_B[n] = ElementAccumulator();
if (d >= 0 && d < problem_size.D &&
h >= 0 && h < problem_size.H &&
w >= 0 && w < problem_size.W &&
thread_c[n] < problem_size.C) {
element_B[n] = ElementAccumulator(tensor_x.at({N, d, h, w, thread_c[n]}));
}
}
// Accumulate matrix product
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
}
}
} // for (Q)
} // for (P)
} // for (Z)
} // for (N)
// Write out the results
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
int thread_k = k_start + m;
if (thread_k < problem_size.K) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
if (thread_t[n] < problem_size.T &&
thread_r[n] < problem_size.R &&
thread_s[n] < problem_size.S &&
thread_c[n] < problem_size.C) {
ElementCompute c_ref = ElementCompute();
if (beta != ElementCompute()) {
c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]}));
}
tensor_dw_out.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]}) = convert_op(
alpha * ElementCompute(accum[m][n]) + beta * c_ref);
}
}
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Conv2d Fprop dispatcher - y = fprop(x, w)
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementCompute,
typename ElementAccumulator = ElementCompute,
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
typename InnerProductOp = multiply_add<ElementAccumulator>
>
Status Conv2dFprop(
conv::Conv2dProblemSize problem_size,
TensorRef<ElementA, LayoutA> tensor_x,
TensorRef<ElementB, LayoutB> tensor_w,
TensorRef<ElementC, LayoutC> tensor_y_in,
TensorRef<ElementC, LayoutC> tensor_y_out,
ElementCompute alpha,
ElementCompute beta,
cudaStream_t stream = nullptr) {
//
// Blocking factors improve performance of reference implementation
//
int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension
int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
int const kCtaShapeM = 16; // shape of a threadblock in units of threads
int const kCtaShapeN = 8; // shape of a threadblock in units of threads
int64_t npq = int64_t(problem_size.N) * problem_size.P * problem_size.Q;
int64_t blocks_m = (npq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM);
dim3 block(kCtaShapeM, kCtaShapeN);
dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN));
kernel::Conv2dFprop<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementCompute,
ElementAccumulator,
ConvertOp,
InnerProductOp,
kThreadM,
kThreadN,
kCtaShapeM,
kCtaShapeN
><<< grid, block, 0, stream >>>(
problem_size,
tensor_x,
tensor_w,
tensor_y_in,
tensor_y_out,
alpha,
beta
);
cudaError_t result = cudaPeekAtLastError();
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
return Status::kSuccess;
}
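/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Usage sketch for the Conv2dFprop dispatcher above. Illustrative only: it assumes the caller
// builds NHWC tensors with cutlass::HostTensor (cutlass/util/host_tensor.h), syncs them to the
// device, and the problem extents shown are hypothetical.
//
//   cutlass::conv::Conv2dProblemSize problem_size(
//     {1, 32, 32, 64},    // activation extent (N, H, W, C)
//     {128, 3, 3, 64},    // filter extent (K, R, S, C)
//     {1, 1, 1, 1},       // padding (pad_h, pad_h, pad_w, pad_w)
//     {1, 1},             // stride (stride_h, stride_w)
//     {1, 1});            // dilation (dilation_h, dilation_w)
//
//   cutlass::HostTensor<float, cutlass::layout::TensorNHWC> x(problem_size.activation_extent());
//   cutlass::HostTensor<float, cutlass::layout::TensorNHWC> w(problem_size.filter_extent());
//   cutlass::HostTensor<float, cutlass::layout::TensorNHWC> y(problem_size.output_extent());
//
//   // ... fill x and w on the host, then x.sync_device(), w.sync_device(), y.sync_device() ...
//
//   cutlass::Status status = cutlass::reference::device::Conv2dFprop<
//     float, cutlass::layout::TensorNHWC,   // ElementA, LayoutA
//     float, cutlass::layout::TensorNHWC,   // ElementB, LayoutB
//     float, cutlass::layout::TensorNHWC,   // ElementC, LayoutC
//     float                                 // ElementCompute
//   >(problem_size, x.device_ref(), w.device_ref(), y.device_ref(), y.device_ref(),
//     1.0f /*alpha*/, 0.0f /*beta*/);
//
/////////////////////////////////////////////////////////////////////////////////////////////////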
/// Conv3d Fprop dispatcher - y = fprop(x, w)
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementCompute,
typename ElementAccumulator = ElementCompute,
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
typename InnerProductOp = multiply_add<ElementAccumulator>
>
Status Conv3dFprop(
conv::Conv3dProblemSize problem_size,
TensorRef<ElementA, LayoutA> tensor_x,
TensorRef<ElementB, LayoutB> tensor_w,
TensorRef<ElementC, LayoutC> tensor_y_in,
TensorRef<ElementC, LayoutC> tensor_y_out,
ElementCompute alpha,
ElementCompute beta,
cudaStream_t stream = nullptr) {
//
// Blocking factors improve performance of reference implementation
//
int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension
int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
int const kCtaShapeM = 16; // shape of a threadblock in units of threads
int const kCtaShapeN = 8; // shape of a threadblock in units of threads
int64_t nzpq = int64_t(problem_size.N) * problem_size.Z * problem_size.P * problem_size.Q;
int64_t blocks_m = (nzpq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM);
dim3 block(kCtaShapeM, kCtaShapeN);
dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN));
kernel::Conv3dFprop<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementCompute,
ElementAccumulator,
ConvertOp,
InnerProductOp,
kThreadM,
kThreadN,
kCtaShapeM,
kCtaShapeN
><<< grid, block, 0, stream >>>(
problem_size,
tensor_x,
tensor_w,
tensor_y_in,
tensor_y_out,
alpha,
beta
);
cudaError_t result = cudaPeekAtLastError();
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
return Status::kSuccess;
}
/// Conv2d Dgrad dispatcher - dx = dgrad(dy, w)
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementCompute,
typename ElementAccumulator = ElementCompute,
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
typename InnerProductOp = multiply_add<ElementAccumulator>
>
Status Conv2dDgrad(
conv::Conv2dProblemSize problem_size,
TensorRef<ElementA, LayoutA> tensor_dy,
TensorRef<ElementB, LayoutB> tensor_w,
TensorRef<ElementC, LayoutC> tensor_dx_in,
TensorRef<ElementC, LayoutC> tensor_dx_out,
ElementCompute alpha,
ElementCompute beta,
cudaStream_t stream = nullptr) {
//
// Blocking factors improve performance of reference implementation
//
int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension
int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
int const kCtaShapeM = 16; // shape of a threadblock in units of threads
int const kCtaShapeN = 8; // shape of a threadblock in units of threads
int64_t nhw = int64_t(problem_size.N) * problem_size.H * problem_size.W;
int64_t blocks_m = (nhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM);
dim3 block(kCtaShapeM, kCtaShapeN);
dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN));
kernel::Conv2dDgrad<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementCompute,
ElementAccumulator,
ConvertOp,
InnerProductOp,
kThreadM,
kThreadN,
kCtaShapeM,
kCtaShapeN
><<< grid, block, 0, stream >>>(
problem_size,
tensor_dy,
tensor_w,
tensor_dx_in,
tensor_dx_out,
alpha,
beta
);
cudaError_t result = cudaPeekAtLastError();
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
return Status::kSuccess;
}
/// Conv3d Dgrad dispatcher - dx = dgrad(dy, w)
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementCompute,
typename ElementAccumulator = ElementCompute,
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
typename InnerProductOp = multiply_add<ElementAccumulator>
>
Status Conv3dDgrad(
conv::Conv3dProblemSize problem_size,
TensorRef<ElementA, LayoutA> tensor_dy,
TensorRef<ElementB, LayoutB> tensor_w,
TensorRef<ElementC, LayoutC> tensor_dx_in,
TensorRef<ElementC, LayoutC> tensor_dx_out,
ElementCompute alpha,
ElementCompute beta,
cudaStream_t stream = nullptr) {
//
// Blocking factors improve performance of reference implementation
//
int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension
int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
int const kCtaShapeM = 16; // shape of a threadblock in units of threads
int const kCtaShapeN = 8; // shape of a threadblock in units of threads
int64_t ndhw = int64_t(problem_size.N) * problem_size.D * problem_size.H * problem_size.W;
int64_t blocks_m = (ndhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM);
dim3 block(kCtaShapeM, kCtaShapeN);
dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN));
kernel::Conv3dDgrad<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementCompute,
ElementAccumulator,
ConvertOp,
InnerProductOp,
kThreadM,
kThreadN,
kCtaShapeM,
kCtaShapeN
><<< grid, block, 0, stream >>>(
problem_size,
tensor_dy,
tensor_w,
tensor_dx_in,
tensor_dx_out,
alpha,
beta
);
cudaError_t result = cudaPeekAtLastError();
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
return Status::kSuccess;
}
/// Conv2d Wgrad dispatcher - dw = wgrad(dy, x)
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementCompute,
typename ElementAccumulator = ElementCompute,
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
typename InnerProductOp = multiply_add<ElementAccumulator>
>
Status Conv2dWgrad(
conv::Conv2dProblemSize problem_size,
TensorRef<ElementA, LayoutA> tensor_dy,
TensorRef<ElementB, LayoutB> tensor_x,
TensorRef<ElementC, LayoutC> tensor_dw_in,
TensorRef<ElementC, LayoutC> tensor_dw_out,
ElementCompute alpha,
ElementCompute beta,
cudaStream_t stream = nullptr) {
//
// Blocking factors improve performance of reference implementation
//
int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension
int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
int const kCtaShapeM = 8; // shape of a threadblock in units of threads
int const kCtaShapeN = 16; // shape of a threadblock in units of threads
int64_t rsc = int64_t(problem_size.R) * problem_size.S * problem_size.C;
int64_t blocks_n = (rsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN);
dim3 block(kCtaShapeM, kCtaShapeN);
dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n));
kernel::Conv2dWgrad<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementCompute,
ElementAccumulator,
ConvertOp,
InnerProductOp,
kThreadM,
kThreadN,
kCtaShapeM,
kCtaShapeN
><<< grid, block, 0, stream >>>(
problem_size,
tensor_dy,
tensor_x,
tensor_dw_in,
tensor_dw_out,
alpha,
beta
);
cudaError_t result = cudaPeekAtLastError();
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
return Status::kSuccess;
}
/// Conv3d Wgrad dispatcher - dw = wgrad(dy, x)
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementCompute,
typename ElementAccumulator = ElementCompute,
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
typename InnerProductOp = multiply_add<ElementAccumulator>
>
Status Conv3dWgrad(
conv::Conv3dProblemSize problem_size,
TensorRef<ElementA, LayoutA> tensor_dy,
TensorRef<ElementB, LayoutB> tensor_x,
TensorRef<ElementC, LayoutC> tensor_dw_in,
TensorRef<ElementC, LayoutC> tensor_dw_out,
ElementCompute alpha,
ElementCompute beta,
cudaStream_t stream = nullptr) {
//
// Blocking factors improve performance of reference implementation
//
int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension
int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
int const kCtaShapeM = 8; // shape of a threadblock in units of threads
int const kCtaShapeN = 16; // shape of a threadblock in units of threads
int64_t trsc = int64_t(problem_size.T) * problem_size.R * problem_size.S * problem_size.C;
int64_t blocks_n = (trsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN);
dim3 block(kCtaShapeM, kCtaShapeN);
dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n));
kernel::Conv3dWgrad<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementCompute,
ElementAccumulator,
ConvertOp,
InnerProductOp,
kThreadM,
kThreadN,
kCtaShapeM,
kCtaShapeN
><<< grid, block, 0, stream >>>(
problem_size,
tensor_dy,
tensor_x,
tensor_dw_in,
tensor_dw_out,
alpha,
beta
);
cudaError_t result = cudaPeekAtLastError();
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
return Status::kSuccess;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad.
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementCompute,
typename ElementAccumulator = ElementCompute,
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
typename InnerProductOp = multiply_add<ElementAccumulator>
>
Status Conv2d(
conv::Operator convolutional_operator,
conv::Conv2dProblemSize problem_size,
TensorRef<ElementA, LayoutA> tensor_A,
TensorRef<ElementB, LayoutB> tensor_B,
TensorRef<ElementC, LayoutC> tensor_C,
TensorRef<ElementC, LayoutC> tensor_D,
ElementCompute alpha,
ElementCompute beta,
cudaStream_t stream = nullptr) {
switch (convolutional_operator) {
case conv::Operator::kFprop:
return Conv2dFprop<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ElementCompute,
ElementAccumulator,
ConvertOp, InnerProductOp
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
case conv::Operator::kDgrad:
return Conv2dDgrad<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ElementCompute,
ElementAccumulator,
ConvertOp, InnerProductOp
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
case conv::Operator::kWgrad:
return Conv2dWgrad<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ElementCompute,
ElementAccumulator,
ConvertOp, InnerProductOp
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
default: break;
}
return Status::kErrorNotSupported;
}
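// A minimal sketch of the generic entry point, reusing the hypothetical tensors from the usage
// comment after Conv2dFprop above:
//
//   cutlass::Status status = cutlass::reference::device::Conv2d<
//     float, cutlass::layout::TensorNHWC,
//     float, cutlass::layout::TensorNHWC,
//     float, cutlass::layout::TensorNHWC,
//     float
//   >(cutlass::conv::Operator::kFprop, problem_size,
//     x.device_ref(), w.device_ref(), y.device_ref(), y.device_ref(), 1.0f, 0.0f);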
/// Generic 3D convolution targeting Conv3dFprop, Conv3dDgrad, and Conv3dWgrad.
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementCompute,
typename ElementAccumulator = ElementCompute,
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
typename InnerProductOp = multiply_add<ElementAccumulator>
>
Status Conv3d(
conv::Operator convolutional_operator,
conv::Conv3dProblemSize problem_size,
TensorRef<ElementA, LayoutA> tensor_A,
TensorRef<ElementB, LayoutB> tensor_B,
TensorRef<ElementC, LayoutC> tensor_C,
TensorRef<ElementC, LayoutC> tensor_D,
ElementCompute alpha,
ElementCompute beta,
cudaStream_t stream = nullptr) {
switch (convolutional_operator) {
case conv::Operator::kFprop:
return Conv3dFprop<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ElementCompute,
ElementAccumulator,
ConvertOp, InnerProductOp
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
case conv::Operator::kDgrad:
return Conv3dDgrad<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ElementCompute,
ElementAccumulator,
ConvertOp, InnerProductOp
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
case conv::Operator::kWgrad:
return Conv3dWgrad<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ElementCompute,
ElementAccumulator,
ConvertOp, InnerProductOp
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
default: break;
}
return Status::kErrorNotSupported;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace reference
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////////////////////////