release 2.11 (#703)

Aditya Atluri
2022-11-19 06:02:15 -08:00
committed by GitHub
parent 3c90f6aea6
commit c975e2ccbb
329 changed files with 47332 additions and 10607 deletions

View File

@ -80,9 +80,9 @@ public:
typedef value_type *pointer;
typedef value_type const * const_pointer;
- using ArrayType = Array<T, N>;
- using reference = typename ArrayType::reference;
- using const_reference = typename ArrayType::const_reference;
+ using Array = Array<T, N>;
+ using reference = typename Array::reference;
+ using const_reference = typename Array::const_reference;
public:

View File

@ -85,6 +85,10 @@ struct Sm86 {
static int const kMinComputeCapability = 86;
};
struct Sm90 {
static int const kMinComputeCapability = 90;
};
/// Triggers a breakpoint on the device
CUTLASS_DEVICE
void device_breakpoint() {

View File

@ -451,7 +451,7 @@ template <>
CUTLASS_DEVICE
void shared_store<16>(uint32_t ptr, void const *src) {
uint4 const *dst_u128 = reinterpret_cast<uint4 const *>(src);
asm volatile("ld.shared.v4.u32 [%0], {%1, %2, %3, %4};\n"
asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n"
: :
"r"(ptr),
"r"(dst_u128->x),

View File

@ -223,4 +223,6 @@ struct SparseMma;
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/mma_sm80.h"
#include "cutlass/arch/mma_sparse_sm80.h"
#include "cutlass/arch/mma_sm90.h"
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1065,7 +1065,7 @@ struct Mma<
int const *C = reinterpret_cast<int const *>(&c);
int *D = reinterpret_cast<int *>(&d);
asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
asm volatile("_mma.m8n8k32.row.col.u4.s4.sat {%0,%1}, %2, %3, {%4,%5};\n"
: "=r"(D[0]), "=r"(D[1])
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
@ -1247,7 +1247,8 @@ struct Mma<
) const {
#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
- #if defined(CUTLASS_ARCH_WMMA_ENABLED)
+ #if (__CUDA_ARCH__ >= 900) || (defined(CUTLASS_ARCH_WMMA_ENABLED))
using WmmaFragmentA = nvcuda::wmma::fragment<
nvcuda::wmma::matrix_a,
Shape::kM,
@ -1279,6 +1280,7 @@ struct Mma<
nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR,
nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);
#else
CUTLASS_UNUSED(a);
@ -1289,14 +1291,7 @@ struct Mma<
#endif // defined(CUTLASS_ARCH_WMMA_ENABLED)
#else
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_UNUSED(d);
assert(0);
#endif
}
};

View File

@ -2156,6 +2156,7 @@ struct Mma<
int const *C = reinterpret_cast<int const *>(&c);
int *D = reinterpret_cast<int *>(&d);
asm volatile(
"mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc {%0,%1,%2,%3}, "
"{%4,%5,%6,%7}, "

View File

@ -0,0 +1,131 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Matrix multiply
*/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "mma.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
////////////////////////////////////////////////////////////////////////////////
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 8))
#define CUTLASS_ARCH_MMA_SM90_SUPPORTED 1
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
#define CUTLASS_ARCH_MMA_SM90_ENABLED
#endif
#endif
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace arch {
////////////////////////////////////////////////////////////////////////////////
/// Matrix Multiply-Add 16x8x4 fp64
////////////////////////////////////////////////////////////////////////////////
/// Matrix multiply-add operation: F64 = F64 * F64 + F64
template <>
struct Mma<
gemm::GemmShape<16,8,4>,
32,
double,
layout::RowMajor,
double,
layout::ColumnMajor,
double,
layout::RowMajor,
OpMultiplyAdd> {
using Shape = gemm::GemmShape<16,8,4>;
using ElementA = double;
using LayoutA = layout::RowMajor;
using FragmentA = Array<double, 2>;
using ElementB = double;
using LayoutB = layout::ColumnMajor;
using FragmentB = Array<double, 1>;
using ElementC = double;
using LayoutC = layout::RowMajor;
using FragmentC = Array<double, 4>;
using Operator = OpMultiplyAdd;
using ArchTag = arch::Sm90;
CUTLASS_HOST_DEVICE
void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
FragmentC const &c) const {
#if defined(CUTLASS_ARCH_MMA_SM90_ENABLED)
double const *A = reinterpret_cast<double const *>(&a);
double const *B = reinterpret_cast<double const *>(&b);
double const *C = reinterpret_cast<double const *>(&c);
double *D = reinterpret_cast<double *>(&d);
asm volatile("mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3])
: "d"(A[0]), "d"(A[1]),
"d"(B[0]),
"d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3]));
#else
CUTLASS_UNUSED(d);
CUTLASS_UNUSED(a);
CUTLASS_UNUSED(b);
CUTLASS_UNUSED(c);
CUTLASS_NOT_IMPLEMENTED();
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace arch
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
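Illustrative usage of the new SM90 FP64 atom (the wrapper function is an assumption for demonstration; on targets without CUTLASS_ARCH_MMA_SM90_ENABLED the operator falls through to CUTLASS_NOT_IMPLEMENTED()):

// Instantiate the 16x8x4 double-precision tensor-core atom defined above and
// issue one warp-wide multiply-accumulate. Fragments are assumed populated.
using Dmma = cutlass::arch::Mma<
    cutlass::gemm::GemmShape<16, 8, 4>, 32,
    double, cutlass::layout::RowMajor,
    double, cutlass::layout::ColumnMajor,
    double, cutlass::layout::RowMajor,
    cutlass::arch::OpMultiplyAdd>;

__device__ void dmma_step(Dmma::FragmentC &d,
                          Dmma::FragmentA const &a,
                          Dmma::FragmentB const &b,
                          Dmma::FragmentC const &c) {
  Dmma mma;
  mma(d, a, b, c);  // lowers to mma.sync.aligned.m16n8k4...f64 on sm_90
}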

File diff suppressed because it is too large

include/cutlass/barrier.h (new file, 201 lines)
View File

@ -0,0 +1,201 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implementation of a CTA-wide barrier for inter-CTA synchronization.
*/
#pragma once
#include "cutlass/cutlass.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// CTA-wide semaphore for inter-CTA synchronization.
struct Barrier
{
public:
/// Flag type
using T = int;
/// Initial flag value
static const T INIT = 0;
protected:
/// Load flag, as a strong operation (int specialization)
CUTLASS_DEVICE
static int ld_strong(int *ptr)
{
int state = 0;
#if (__CUDA_ARCH__ >= 700)
/// SM70 and newer use memory consistency qualifiers
asm volatile ("ld.global.relaxed.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));
#else
asm volatile ("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));
#endif // (__CUDA_ARCH__ >= 700)
return state;
}
/// Store flag, as a strong operation (int specialization)
CUTLASS_DEVICE
static void st_strong(int *ptr, int val)
{
#if (__CUDA_ARCH__ >= 700)
/// SM70 and newer use memory consistency qualifiers
asm volatile ("st.global.relaxed.gpu.b32 [%0], %1;\n" : : "l"(ptr), "r"(val));
#else
asm volatile ("st.cg.global.b32 [%0], %1;\n" : : "l"(ptr), "r"(val));
#endif // (__CUDA_ARCH__ >= 700)
}
/// Reduce into flag, with release pattern (int specialization)
CUTLASS_DEVICE
static void red_release(int *ptr, int val)
{
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
#if (__CUDA_ARCH__ >= 700)
/// SM70 and newer use memory consistency qualifiers
asm volatile ("red.release.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val));
#else
__threadfence();
atomicAdd(ptr, val);
#endif // (__CUDA_ARCH__ >= 700)
#endif
}
public:
/// Uses thread[0] to wait for at least the specified count of signals on the given flag counter
CUTLASS_DEVICE
static void wait_lt(void *lock_ptr, int thread_idx, int flag_idx, int count)
{
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
T *flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;
if (thread_idx == 0)
{
// Spin-loop
#pragma unroll 1
while(ld_strong(flag_ptr) < count) {}
}
__syncthreads();
#endif
}
/// Uses thread[0] to wait for at least the specified count of signals on the given flag counter
CUTLASS_DEVICE
static void wait_eq(void *lock_ptr, int thread_idx, int flag_idx, T val = 1)
{
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
T *flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;
if (thread_idx == 0)
{
// Spin-loop
#pragma unroll 1
while(ld_strong(flag_ptr) != val) {}
}
__syncthreads();
#endif
}
/// Uses thread[0] to wait for the specified count of signals on the given flag counter
CUTLASS_DEVICE
static void wait_eq_reset(void *lock_ptr, int thread_idx, int flag_idx, T val = 1) {
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
T *flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;
if (thread_idx == 0)
{
// Spin-loop
#pragma unroll 1
while(atomicCAS(flag_ptr, val, 0) != val) {}
}
__syncthreads();
#endif
}
/// Increment the arrival count for a flag
CUTLASS_DEVICE
static void arrive_inc(void *lock_ptr, int thread_idx, int flag_idx)
{
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
T* flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;
__syncthreads();
if (thread_idx == 0) {
red_release(flag_ptr, 1);
}
#endif
}
/// Increment the arrival counts for a range of flags
CUTLASS_DEVICE
static void arrive_range_inc(void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1)
{
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
int flag_idx = first_flag_idx + thread_idx;
T* flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;
// Barrier to make sure all other threads in block have written their data
__syncthreads();
// Selected threads increment their flags
if (thread_idx < count) {
red_release(flag_ptr, 1);
}
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
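A hedged usage sketch of the flag protocol (the kernel, tile buffer, and one-flag-per-CTA layout are assumptions; the workspace must be zero-initialized to Barrier::INIT before launch):

__global__ void producer_consumer(int *workspace /*zeroed*/, float *tiles) {
  int flag_idx = blockIdx.x;  // one flag per producer CTA (assumed layout)
  // ... all threads cooperatively produce this CTA's tile into `tiles` ...
  // Signal completion: __syncthreads() inside, then thread 0 does red.release.
  cutlass::Barrier::arrive_inc(workspace, threadIdx.x, flag_idx);
  if (blockIdx.x > 0) {
    // Spin (thread 0 only) until the previous CTA has signaled at least once.
    cutlass::Barrier::wait_lt(workspace, threadIdx.x, blockIdx.x - 1, 1);
    // ... now safe to read the previous CTA's tile ...
  }
}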

View File

@ -35,7 +35,9 @@
*/
#pragma once
- #if !defined(__CUDACC_RTC__)
+ #if defined(__CUDACC_RTC__)
+ #include "cutlass/floating_point_nvrtc.h"
+ #else
#include <cmath>
#include <limits>
#include <cstdint>
@ -71,8 +73,7 @@ struct alignas(2) bfloat16_t {
}
/// Default constructor
CUTLASS_HOST_DEVICE
- bfloat16_t() : storage(0) { }
+ bfloat16_t() = default;
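The defaulted constructor makes bfloat16_t trivially default-constructible, so declaring large arrays of it no longer zero-fills each element; an illustrative check, not from the commit (and <type_traits> assumes a non-NVRTC translation unit):

#include <type_traits>
static_assert(std::is_trivially_default_constructible<cutlass::bfloat16_t>::value,
              "bfloat16_t default construction is trivial after this change");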
/// Floating-point conversion - round toward nearest
CUTLASS_HOST_DEVICE

View File

@ -40,6 +40,7 @@
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/complex.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_types.h"

View File

@ -0,0 +1,259 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Utilities for performing block-striped access (load, store, reduce) of trivially-copyable,
statically-sized array types to global memory.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/wmma_array.h"
#include "cutlass/functional.h"
#include "cutlass/complex.h"
namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
// AccessWidth
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Computes the maximal power-of-two that evenly divides the size of T, capped at Limit
template <
typename T,
int Limit>
struct AccessWidth
{
// Inductive case
template <
int ObjectBytes, /// Size of T in bytes
int AlignBytes, /// Template induction variable
bool IsAligned = /// Whether ObjectBytes is an even multiple of AlignBytes
((AlignBytes <= Limit) && (ObjectBytes % AlignBytes == 0))>
struct Detail
{
static const int value = Detail<ObjectBytes, AlignBytes * 2>::value;
};
// Base case (ObjectBytes is not an even multiple of AlignBytes)
template <
int ObjectBytes, /// Size of T in bytes
int AlignBytes> /// Template induction variable
struct Detail<ObjectBytes, AlignBytes, false>
{
static const int value = AlignBytes / 2;
};
/// The maximal power-of-two that evenly divides the size of T
static const int value = Detail<
(int) sizeof(T),
1>::value;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// StripedAccessType
/////////////////////////////////////////////////////////////////////////////////////////////////
/// ReinterpretCast type for striping a trivially-copyable type in global memory
/// (Default specialization. Striping granularity is type T.)
template <
typename T, /// Data type
int TransferBytes = /// Data access width (16 byte max for global memory access on current architectures)
AccessWidth<T, 16>::value>
struct alignas(TransferBytes) StripedAccessType : public T
{};
/// ReinterpretCast type for striping a trivially-copyable type in global memory
/// (Specialization for cutlass::Array<T>. Striping granularity is a multiple of T.)
template <
typename T, /// Array element type
int N, /// Number of elements in array
bool RegisterSized, /// T is register-sized
int TransferBytes> /// Data access width
struct StripedAccessType<
Array<T, N, RegisterSized>,
TransferBytes>
: public AlignedArray<
T, // Element type of StripedAccessType
__NV_STD_MAX(1, TransferBytes / (int) sizeof(T)), // Number of elements T in StripedAccessType
TransferBytes> // Alignment of StripedAccessType
{};
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
/// ReinterpretCast type for striping a trivially-copyable type in global memory
/// (Specialization for cutlass::WmmaFragmentArray<T>. Striping granularity is a multiple of T.)
template<
typename Use,
int m,
int n,
int k,
typename ElementT,
typename Layout,
int kFragments,
int TransferBytes>
struct StripedAccessType<
WmmaFragmentArray<nvcuda::wmma::fragment<Use, m, n, k, ElementT, Layout>, kFragments>,
TransferBytes>
: public AlignedArray<
ElementT,
__NV_STD_MAX(1, TransferBytes / (int) sizeof(ElementT)),
TransferBytes>
{};
#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED)
/////////////////////////////////////////////////////////////////////////////////////////////////
// BlockStriped
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Utility for performing block-striped access (load, store) of trivially-copyable,
/// statically-sized array types to global memory
template <
int BlockThreads,
typename ArrayT,
typename T,
typename AccessT = StripedAccessType<ArrayT> >
struct BlockStriped
{
/// Number of striped accesses
static const int kStripes = int(sizeof(ArrayT) / sizeof(AccessT));
static_assert(kStripes > 0, "AccessT type must be smaller than or equal to ArrayT type");
/// Load
CUTLASS_DEVICE
static void load(ArrayT &data, T *ptr, int thread_idx)
{
AccessT *access_input = reinterpret_cast<AccessT*>(ptr);
AccessT *access_data = reinterpret_cast<AccessT*>(&data);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kStripes; ++i) {
access_data[i] = access_input[(BlockThreads * i) + thread_idx];
}
}
/// Load & Add
CUTLASS_DEVICE
static void load_add(ArrayT &data, T *ptr, int thread_idx)
{
AccessT *access_input = reinterpret_cast<AccessT*>(ptr);
AccessT *access_data = reinterpret_cast<AccessT*>(&data);
plus<AccessT> add;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kStripes; ++i)
{
access_data[i] = add(access_data[i], access_input[(BlockThreads * i) + thread_idx]);
}
}
/// Store
CUTLASS_DEVICE
static void store(T *ptr, const ArrayT &data, int thread_idx)
{
AccessT *access_output = reinterpret_cast<AccessT*>(ptr);
const AccessT *access_data = reinterpret_cast<const AccessT*>(&data);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kStripes; ++i) {
access_output[(BlockThreads * i) + thread_idx] = access_data[i];
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// BlockStripedReduce
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Utility for performing block-striped access (load, store, reduce) of trivially-copyable,
/// statically-sized array types to global memory.
/// (Default specialization)
template <
int BlockThreads,
typename ArrayT,
typename T>
struct BlockStripedReduce : BlockStriped<BlockThreads, ArrayT, T, T>
{
/// Reduce
CUTLASS_DEVICE
static void reduce(T *ptr, const ArrayT &data, int thread_idx)
{
cutlass::red<T> reduce;
const T *access_data = reinterpret_cast<const T*>(&data);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < BlockStripedReduce::kStripes; ++i) {
reduce(ptr + (BlockThreads * i) + thread_idx, access_data[i]);
}
}
};
/// Utility for performing block-striped access (load, store, reduce) of trivially-copyable,
/// statically-sized array types to global memory.
/// (Specialization for half_t. Uses half2 vectorized-reduction.)
template <
int BlockThreads,
typename ArrayT>
struct BlockStripedReduce<BlockThreads, ArrayT, half_t> : BlockStriped<BlockThreads, ArrayT, half_t, half2>
{
static_assert(BlockStripedReduce::kStripes % 2 == 0, "Array of half must have an even number of elements");
/// Reduce
CUTLASS_DEVICE
static void reduce(half_t *ptr, const ArrayT &data, int thread_idx)
{
cutlass::red<half2> reduce;
half2 *access_output = reinterpret_cast<half2*>(ptr);
const half2 *access_data = reinterpret_cast<const half2*>(&data);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < BlockStripedReduce::kStripes; ++i)
{
reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]);
}
}
};
} // namespace cutlass
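A minimal sketch of block-striped staging (thread count and array size are assumptions):

// 128 threads cooperatively fold a 1024-float global slab into per-thread
// fragments and write it back; each thread makes kStripes coalesced accesses.
constexpr int kBlockThreads = 128;
using Fragment = cutlass::Array<float, 1024 / kBlockThreads>;  // 8 floats per thread
using Striped = cutlass::BlockStriped<kBlockThreads, Fragment, float>;

__device__ void accumulate_slab(float *slab, Fragment &frag, int thread_idx) {
  Striped::load_add(frag, slab, thread_idx);  // frag += striped slab
  Striped::store(slab, frag, thread_idx);     // write the sum back
}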

View File

@ -32,6 +32,8 @@
#include <cuComplex.h>
#include <cuda_fp16.h>
#if defined(__CUDACC_RTC__)
#include <cuda/std/cstdint>
#else
@ -39,6 +41,7 @@
#endif
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/half.h"
#include "cutlass/real.h"
@ -53,8 +56,10 @@
namespace cutlass {
//////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Enumerated type describing a transformation on a complex value.
enum class ComplexTransform {
kNone,
@ -147,15 +152,18 @@ class complex
// Methods
//
- /// Constructor
- CUTLASS_HOST_DEVICE
- complex(T r = T(0)) : _real(r), _imag(T(0)) {}
+ /// Default constructor
+ complex() = default;
+ /// Constructor
+ CUTLASS_HOST_DEVICE
+ complex(T r) : _real(r), _imag(T(0)) {}
/// Constructor
CUTLASS_HOST_DEVICE
complex(T r, T i) : _real(r), _imag(i) {}
- //
/// Constructor
template<typename A>
CUTLASS_HOST_DEVICE
complex(complex<A> const &z) : _real(static_cast<T>(z.real())), _imag(static_cast<T>(z.imag())) {}
@ -197,6 +205,24 @@ class complex
return complex<T>(this->real() + rhs.real(), this->imag() + rhs.imag());
}
/// Reduction into memory address. Components may update out of order.
template <typename OtherT>
CUTLASS_DEVICE void red(complex<OtherT> *ptr) const {
static_assert(platform::is_same<T, OtherT>::value, "Component type must match");
cutlass::red<T> reduce;
reduce(&ptr->_real, _real);
reduce(&ptr->_imag, _imag);
}
/// Reduction into memory address. Components may update out of order. (Half specialization)
CUTLASS_DEVICE void red(complex<half_t> *ptr) const {
static_assert(platform::is_same<T, half_t>::value, "Component type must match");
half2 *h2_ptr = reinterpret_cast<half2*>(ptr);
half2 h2_data = reinterpret_cast<half2&>(*this);
cutlass::red<half2> reduce;
reduce(h2_ptr, h2_data);
}
/// Subtraction
template <typename A>
CUTLASS_HOST_DEVICE complex<T> operator-(complex<A> const &rhs) const {
@ -506,13 +532,14 @@ CUTLASS_HOST_DEVICE bool operator<(complex<T> const &lhs, complex<T> const &rhs)
/// Partial specialization for complex-valued type.
template <typename T>
- struct RealType< complex<T> > {
+ struct RealType< complex<T> >
+ {
using Type = T;
/// Number of elements
static int const kExtent = 2;
CUTLASS_HOST_DEVICE
static complex<T> from_real(double x) {
return complex<T>(static_cast<T>(x));
}
@ -550,6 +577,127 @@ struct is_complex<complex<T>> {
static bool const value = true;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// functional.h numeric specializations
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Squares with optional conversion
template <typename T, typename Output>
struct magnitude_squared<complex<T>, Output> {
CUTLASS_HOST_DEVICE
Output operator()(complex<T> lhs) const {
multiplies<Output> mul_op;
Output y_r = Output(lhs.real());
Output y_i = Output(lhs.imag());
return mul_op(y_r, y_r) + mul_op(y_i, y_i);
}
};
/// Fused multiply-add
template <typename T>
struct multiply_add<complex<T>, complex<T>, complex<T>> {
CUTLASS_HOST_DEVICE
complex<T> operator()(
complex<T> const &a,
complex<T> const &b,
complex<T> const &c) const {
T real = c.real();
T imag = c.imag();
real += a.real() * b.real();
real += -a.imag() * b.imag();
imag += a.real() * b.imag();
imag += a.imag() * b.real();
return complex<T>{
real,
imag
};
}
};
/// Fused multiply-add
template <typename T>
struct multiply_add<complex<T>, T, complex<T>> {
CUTLASS_HOST_DEVICE
complex<T> operator()(
complex<T> const &a,
T const &b,
complex<T> const &c) const {
T real = c.real();
T imag = c.imag();
real += a.real() * b;
imag += a.imag() * b;
return complex<T>{
real,
imag
};
}
};
/// Fused multiply-add
template <typename T>
struct multiply_add<T, complex<T>, complex<T>> {
CUTLASS_HOST_DEVICE
complex<T> operator()(
T const &a,
complex<T> const &b,
complex<T> const &c) const {
T real = c.real();
T imag = c.imag();
real += a * b.real();
imag += a * b.imag();
return complex<T>{
real,
imag
};
}
};
/// Conjugate
template <typename T>
struct conjugate<complex<T>> {
CUTLASS_HOST_DEVICE
complex<T> operator()(complex<T> const &a) const {
return conj(a);
}
};
/// Computes the square of a difference with optional conversion
template <typename T, typename Output>
struct magnitude_squared_difference<complex<T>, Output> {
CUTLASS_HOST_DEVICE
Output operator()(complex<T> lhs, complex<T> rhs) const {
multiplies<Output> mul_op;
Output y_r = Output(lhs.real()) - Output(rhs.real());
Output y_i = Output(lhs.imag()) - Output(rhs.imag());
return mul_op(y_r, y_r) + mul_op(y_i, y_i);
}
};
/// Reduces value into the data pointed to by ptr (complex<T> specialization)
template <typename T>
struct red<complex<T>> {
CUTLASS_DEVICE
void operator()(complex<T> *ptr, const complex<T> &data)
{
data.red(ptr);
}
};
//////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
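A small device-side sketch of the new complex reduction specialization (the function and names are illustrative):

// Fold a per-thread partial into one global complex accumulator. Real and
// imaginary parts are reduced by separate atomics and may land out of order.
__device__ void accumulate(cutlass::complex<float> *global_acc,
                           cutlass::complex<float> const &partial) {
  cutlass::red<cutlass::complex<float>> reduce;
  reduce(global_acc, partial);
}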

View File

@ -247,7 +247,7 @@ public:
CUTLASS_HOST_DEVICE
cutlass::Tensor4DCoord filter_extent() const {
- return cutlass::Tensor4DCoord ({K, R, S, C});
+ return cutlass::Tensor4DCoord ({K, R, S, C / groups});
}
/// Returns output extent as Tensor4DCoord
@ -336,7 +336,7 @@ cutlass::gemm::GemmCoord implicit_gemm_problem_size(
return gemm::GemmCoord(
problem_size.N * problem_size.P * problem_size.Q,
problem_size.K,
- problem_size.R * problem_size.S * problem_size.C
+ problem_size.R * problem_size.S * problem_size.C / problem_size.groups
);
case Operator::kDgrad:
return gemm::GemmCoord(
@ -451,6 +451,18 @@ int implicit_gemm_k_iterations(
default:
break;
}
} else if (algorithm == IteratorAlgorithm::kOptimized) {
// The current optimized iterator only supports GroupMode::kSingleGroup
if (group_mode == GroupMode::kSingleGroup) {
switch (conv_operator) {
case Operator::kFprop:
iterations = problem_size.R * problem_size.S * ((channels_per_group + threadblock_K - 1) / threadblock_K);
break;
default:
break;
}
}
}
}
@ -459,6 +471,25 @@ int implicit_gemm_k_iterations(
}
template <int N = 1, int Output_P = 1, int Output_Q = 1>
CUTLASS_HOST_DEVICE
int depthwise_gemm_k_iterations(
Operator conv_operator,
int threadblock_K,
Conv2dProblemSize const &problem_size,
IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic,
GroupMode group_mode = GroupMode::kNone,
int threadblock_N = 0) {
int n = problem_size.N;
int p = (problem_size.P + Output_P - 1) / Output_P;
int q = (problem_size.Q + Output_Q - 1) / Output_Q;
int iterations = (n * p * q + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
return iterations;
}
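As a worked example of the two iteration counts above (sizes assumed for illustration): for a grouped Fprop with R = S = 3, C = 128, groups = 4, and threadblock_K = 32, each filter sees C / groups = 32 input channels (matching the corrected filter_extent above), so the optimized single-group mainloop runs 3 * 3 * ceil(32 / 32) = 9 k-iterations. For the depthwise helper with N = 1, P = Q = 28, Output_P = Output_Q = 1, and split_k_slices = 1, iterations = ceil(1 * 28 * 28 / 1) = 784.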
CUTLASS_HOST_DEVICE
int implicit_gemm_k_iterations_per_channel(
Operator conv_operator,

View File

@ -100,14 +100,16 @@ enum class IteratorAlgorithm {
kAnalytic, ///< functionally correct in all cases but lower performance
kOptimized, ///< optimized for R <= 32, S <= 32 and unity-stride dgrad
kFixedChannels, ///< Analytic algorithm optimized for fixed channel count (C == AccessSize)
- kFewChannels ///< Analytic algorithm optimized for few channels (C divisible by AccessSize)
+ kFewChannels, ///< Analytic algorithm optimized for few channels (C divisible by AccessSize)
+ kFixedStrideDilation ///< Optimized for fixed stride and dilation
};
/// Distinguishes among partial specializations that accelerate certain problems where convolution
/// stride is unit.
enum class StrideSupport {
kStrided, ///< arbitrary convolution stride
- kUnity ///< unit convolution stride
+ kUnity, ///< unit convolution stride
+ kFixed ///< fixed convolution stride
};
/// Identifies split-K mode
@ -125,6 +127,38 @@ enum class GroupMode {
kDepthwise ///< One CTA calculates cta_n groups (problem_size.C == problem_size.K == problem_size.groups)
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Shape of a tensor
template <
int N = 1,
int H = 1,
int W = 1,
int C = 1
>
struct TensorNHWCShape {
static int const kN = N;
static int const kH = H;
static int const kW = W;
static int const kC = C;
static int const kHW = H * W;
static int const kNHW = N * kHW;
static int const kNHWC = N * H * W * C;
static int const kCount = kNHWC;
//
// Static member functions
//
/// Returns a Coord object
CUTLASS_HOST_DEVICE
static Coord<4> toCoord() {
return make_Coord(kN, kH, kW, kC);
}
};
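Illustrative static checks on the new shape helper (the concrete extents are assumptions):

using ActShape = cutlass::conv::TensorNHWCShape<1, 8, 8, 64>;
static_assert(ActShape::kHW == 64, "H * W");
static_assert(ActShape::kNHWC == 4096, "N * H * W * C == kCount");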
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace conv

View File

@ -0,0 +1,269 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for device-level Depthwise Convolution
*/
#pragma once
#include <limits>
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#include "cutlass/conv/convolution.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace conv {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
template<typename DirectConvolutionKernel_>
class DirectConvolution {
public:
using UnderlyingKernel = DirectConvolutionKernel_;
using ElementA = typename UnderlyingKernel::ElementA;
using LayoutA = typename UnderlyingKernel::LayoutA;
using ElementB = typename UnderlyingKernel::ElementB;
using LayoutB = typename UnderlyingKernel::LayoutB;
using ElementC = typename UnderlyingKernel::ElementC;
using LayoutC = typename UnderlyingKernel::LayoutC;
using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator;
using ElementCompute = typename UnderlyingKernel::ElementCompute;
using OperatorClass = typename UnderlyingKernel::OperatorClass;
using ArchTag = typename UnderlyingKernel::ArchTag;
using ThreadblockShape = typename UnderlyingKernel::ThreadblockShape;
using WarpShape = typename UnderlyingKernel::WarpShape;
using InstructionShape = typename UnderlyingKernel::InstructionShape;
using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle;
using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp;
static int const kStages = UnderlyingKernel::kStages;
static int const kConvDim = UnderlyingKernel::kConvDim;
using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator;
using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator;
using MathOperator = typename UnderlyingKernel::MathOperator;
static cutlass::conv::Operator const kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator;
static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm;
static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport;
static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode;
static int const kWarpCount =
(ThreadblockShape::kM / WarpShape::kM) *
(ThreadblockShape::kN / WarpShape::kN) *
(ThreadblockShape::kK / WarpShape::kK);
/// Argument structure
using Arguments = typename UnderlyingKernel::Arguments;
using ReorderKernel = typename UnderlyingKernel::ReorderKernel;
private:
/// Kernel parameters object
typename UnderlyingKernel::Params params_;
public:
/// Constructs Implicit GEMM
DirectConvolution() { }
/// Determines whether the Implicit GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
// dispatch to iterators
Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size);
if (Status::kSuccess != status) {
return status;
}
status = UnderlyingKernel::Mma::IteratorB::can_implement(args.problem_size);
if (Status::kSuccess != status) {
return status;
}
if (kGroupMode != conv::GroupMode::kDepthwise) {
return Status::kErrorInvalidProblem;
}
// C and K should be multiple of groups
if (args.problem_size.K != args.problem_size.groups &&
args.problem_size.C != args.problem_size.groups) {
return Status::kErrorInvalidProblem;
}
static int const kAlignmentC = UnderlyingKernel::Epilogue::OutputTileIterator::kElementsPerAccess;
if (kConvolutionalOperator == conv::Operator::kFprop) {
if (args.problem_size.K % kAlignmentC)
return Status::kErrorMisalignedOperand;
} else if (kConvolutionalOperator == conv::Operator::kDgrad) {
if (args.problem_size.C % kAlignmentC)
return Status::kErrorMisalignedOperand;
} else if (kConvolutionalOperator == conv::Operator::kWgrad) {
if (args.problem_size.C % kAlignmentC)
return Status::kErrorMisalignedOperand;
}
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
dim3 grid = threadblock_swizzle.get_grid_shape(
threadblock_swizzle.get_tiled_shape(
kConvolutionalOperator,
args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
args.problem_size.split_k_slices));
if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
grid.z <= std::numeric_limits<uint16_t>::max())) {
return Status::kErrorInvalidProblem;
}
return Status::kSuccess;
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
return 0;
}
/// Initializes GEMM state from arguments.
Status initialize(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
// initialize the params structure from the arguments
params_ = typename UnderlyingKernel::Params(
args,
static_cast<int *>(workspace)
);
int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage));
if (smem_size >= (48 << 10)) {
cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel<UnderlyingKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
/// Initializes GEMM state from arguments.
Status update(Arguments const &args, void *workspace = nullptr) {
// update the params structure from the arguments
params_.ptr_A = args.ref_A.data();
params_.ptr_B = args.ref_B.data();
params_.ptr_C = args.ref_C.data();
params_.ptr_D = args.ref_D.data();
params_.output_op = args.output_op;
params_.ptr_reordered_B = args.ref_reordered_B.data();
params_.semaphore = static_cast<int *>(workspace);
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
// Launch reorder kernel
if (params_.ptr_reordered_B != nullptr) {
dim3 grid = ReorderKernel::get_grid_shape(params_);
dim3 block = ReorderKernel::get_block_shape();
cutlass::Kernel<ReorderKernel><<<grid, block, 0, stream>>>(params_);
}
// Launch main kernel
ThreadblockSwizzle threadblock_swizzle;
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
dim3 block(32 * kWarpCount, 1, 1);
// Dynamic SMEM size based on input params.
int smem_size = int(params_.get_smem_size());
// Make sure we can use that much shared memory.
cudaError_t status =
cudaFuncSetAttribute(cutlass::Kernel<UnderlyingKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
if (status != cudaSuccess)
return Status::kErrorInternal;
cutlass::Kernel<UnderlyingKernel><<<grid, block, smem_size, stream>>>(params_);
cudaError_t result = cudaGetLastError();
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
return run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
}
return status;
}
int get_smem_size() { return int(params_.get_smem_size()); }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace conv
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
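The host-side call pattern mirrors the existing device-level handles; a hedged sketch, assuming some DirectConvolutionKernel instantiation named DepthwiseKernel (hypothetical):

using Conv = cutlass::conv::device::DirectConvolution<DepthwiseKernel>;

cutlass::Status run_depthwise(typename Conv::Arguments const &args,
                              void *workspace, cudaStream_t stream) {
  Conv op;
  cutlass::Status status = Conv::can_implement(args);
  if (status != cutlass::Status::kSuccess) { return status; }
  // initialize() builds Params and raises the dynamic-smem limit if needed.
  status = op.initialize(args, workspace, stream);
  if (status != cutlass::Status::kSuccess) { return status; }
  // run() launches the filter-reorder kernel (when ptr_reordered_B is set),
  // then the main direct-convolution kernel.
  return op.run(stream);
}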

View File

@ -52,33 +52,33 @@ template<typename ImplicitGemmKernel_>
class ImplicitGemmConvolution {
public:
- using ImplicitGemmKernel = ImplicitGemmKernel_;
+ using UnderlyingKernel = ImplicitGemmKernel_;
- using ElementA = typename ImplicitGemmKernel::ElementA;
- using LayoutA = typename ImplicitGemmKernel::LayoutA;
- using ElementB = typename ImplicitGemmKernel::ElementB;
- using LayoutB = typename ImplicitGemmKernel::LayoutB;
- using ElementC = typename ImplicitGemmKernel::ElementC;
- using LayoutC = typename ImplicitGemmKernel::LayoutC;
- using ElementAccumulator = typename ImplicitGemmKernel::ElementAccumulator;
- using ElementCompute = typename ImplicitGemmKernel::ElementCompute;
- using OperatorClass = typename ImplicitGemmKernel::OperatorClass;
- using ArchTag = typename ImplicitGemmKernel::ArchTag;
- using ThreadblockShape = typename ImplicitGemmKernel::ThreadblockShape;
- using WarpShape = typename ImplicitGemmKernel::WarpShape;
- using InstructionShape = typename ImplicitGemmKernel::InstructionShape;
- using ThreadblockSwizzle = typename ImplicitGemmKernel::ThreadblockSwizzle;
- using EpilogueOutputOp = typename ImplicitGemmKernel::EpilogueOutputOp;
- static int const kStages = ImplicitGemmKernel::kStages;
- static int const kConvDim = ImplicitGemmKernel::kConvDim;
- using WarpMmaOperator = typename ImplicitGemmKernel::WarpMmaOperator;
- using ArchMmaOperator = typename ImplicitGemmKernel::ArchMmaOperator;
- using MathOperator = typename ImplicitGemmKernel::MathOperator;
+ using ElementA = typename UnderlyingKernel::ElementA;
+ using LayoutA = typename UnderlyingKernel::LayoutA;
+ using ElementB = typename UnderlyingKernel::ElementB;
+ using LayoutB = typename UnderlyingKernel::LayoutB;
+ using ElementC = typename UnderlyingKernel::ElementC;
+ using LayoutC = typename UnderlyingKernel::LayoutC;
+ using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator;
+ using ElementCompute = typename UnderlyingKernel::ElementCompute;
+ using OperatorClass = typename UnderlyingKernel::OperatorClass;
+ using ArchTag = typename UnderlyingKernel::ArchTag;
+ using ThreadblockShape = typename UnderlyingKernel::ThreadblockShape;
+ using WarpShape = typename UnderlyingKernel::WarpShape;
+ using InstructionShape = typename UnderlyingKernel::InstructionShape;
+ using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle;
+ using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp;
+ static int const kStages = UnderlyingKernel::kStages;
+ static int const kConvDim = UnderlyingKernel::kConvDim;
+ using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator;
+ using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator;
+ using MathOperator = typename UnderlyingKernel::MathOperator;
- static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemmKernel::kConvolutionalOperator;
- static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = ImplicitGemmKernel::kIteratorAlgorithm;
- static cutlass::conv::StrideSupport const kStrideSupport = ImplicitGemmKernel::kStrideSupport;
- static cutlass::conv::GroupMode const kGroupMode = ImplicitGemmKernel::kGroupMode;
+ static cutlass::conv::Operator const kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator;
+ static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm;
+ static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport;
+ static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode;
static int const kWarpCount =
(ThreadblockShape::kM / WarpShape::kM) *
@ -86,12 +86,12 @@ public:
(ThreadblockShape::kK / WarpShape::kK);
/// Argument structure
- using Arguments = typename ImplicitGemmKernel::Arguments;
+ using Arguments = typename UnderlyingKernel::Arguments;
private:
/// Kernel parameters object
- typename ImplicitGemmKernel::Params params_;
+ typename UnderlyingKernel::Params params_;
public:
@ -102,12 +102,12 @@ public:
static Status can_implement(Arguments const &args) {
// dispatch to iterators
- Status status = ImplicitGemmKernel::Mma::IteratorA::can_implement(args.problem_size);
+ Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size);
if (Status::kSuccess != status) {
return status;
}
- status = ImplicitGemmKernel::Mma::IteratorB::can_implement(args.problem_size);
+ status = UnderlyingKernel::Mma::IteratorB::can_implement(args.problem_size);
if (Status::kSuccess != status) {
return status;
}
@ -138,9 +138,15 @@ public:
if (kGroupMode == conv::GroupMode::kMultipleGroup && ThreadblockShape::kN % k_per_group) {
return Status::kErrorInvalidProblem;
}
// The current optimized iterator algorithm only supports GroupMode::kSingleGroup
if (kIteratorAlgorithm == IteratorAlgorithm::kOptimized &&
kGroupMode != conv::GroupMode::kSingleGroup) {
return Status::kErrorInvalidProblem;
}
}
- static int const kAlignmentC = ImplicitGemmKernel::Epilogue::OutputTileIterator::kElementsPerAccess;
+ static int const kAlignmentC = UnderlyingKernel::Epilogue::OutputTileIterator::kElementsPerAccess;
if (kConvolutionalOperator == conv::Operator::kFprop) {
if (args.problem_size.K % kAlignmentC)
return Status::kErrorMisalignedOperand;
@ -249,15 +255,15 @@ public:
}
// initialize the params structure from the arguments
- params_ = typename ImplicitGemmKernel::Params(
+ params_ = typename UnderlyingKernel::Params(
args,
static_cast<int *>(workspace)
);
- int smem_size = int(sizeof(typename ImplicitGemmKernel::SharedStorage));
+ int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage));
if (smem_size >= (48 << 10)) {
- cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel<ImplicitGemmKernel>,
+ cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel<UnderlyingKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
@ -292,9 +298,9 @@ public:
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
dim3 block(32 * kWarpCount, 1, 1);
- int smem_size = int(sizeof(typename ImplicitGemmKernel::SharedStorage));
+ int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage));
- cutlass::Kernel<ImplicitGemmKernel><<<grid, block, smem_size, stream>>>(params_);
+ cutlass::Kernel<UnderlyingKernel><<<grid, block, smem_size, stream>>>(params_);
cudaError_t result = cudaGetLastError();

View File

@ -89,7 +89,7 @@ template <
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv2dGroupFprop specialization for Analytic IteratorAlgorithm and multistage
- /// pipeline.
+ /// pipeline that supports all GroupMode.
template <
typename ElementA,
typename LayoutA,
@ -135,6 +135,13 @@ struct DefaultConv2dGroupFprop <
AlignmentB
> {
static_assert(std::is_same<LayoutA, cutlass::layout::TensorNHWC>::value,
"Current group conv only supports NHWC layout");
static_assert(std::is_same<LayoutB, cutlass::layout::TensorNHWC>::value,
"Current group conv only supports NHWC layout");
static_assert(std::is_same<LayoutC, cutlass::layout::TensorNHWC>::value,
"Current group conv only supports NHWC layout");
// Define the core components from GEMM
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
@ -215,6 +222,267 @@ struct DefaultConv2dGroupFprop <
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv2dGroupFprop specialization for Optimized IteratorAlgorithm and multistage
/// pipeline that supports GroupMode::kSingleGroup.
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
conv::StrideSupport StrideSupport,
int AlignmentA,
int AlignmentB
>
struct DefaultConv2dGroupFprop <
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
MathOperatorTag,
GroupMode::kSingleGroup,
IteratorAlgorithm::kOptimized,
StrideSupport,
AlignmentA,
AlignmentB
> {
static_assert(std::is_same<LayoutA, cutlass::layout::TensorNHWC>::value,
"Current group conv only supports NHWC layout");
static_assert(std::is_same<LayoutB, cutlass::layout::TensorNHWC>::value,
"Current group conv only supports NHWC layout");
static_assert(std::is_same<LayoutC, cutlass::layout::TensorNHWC>::value,
"Current group conv only supports NHWC layout");
// Define the core components from GEMM
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
Stages, MathOperatorTag>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
using IteratorA =
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA, LayoutA,
ThreadMapA,
AccessTypeA
>;
using SmemIteratorA = typename MmaCore::SmemIteratorA;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
using IteratorB =
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB, LayoutB,
ThreadMapB,
AccessTypeB
>;
using SmemIteratorB = typename MmaCore::SmemIteratorB;
// Warp-level GEMM components
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
using MmaPolicy = typename MmaCore::MmaPolicy;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the Mma
using Mma = threadblock::ImplicitGemmMultistage<
ThreadblockShape,
IteratorA,
SmemIteratorA,
arch::CacheOperation::Always,
IteratorB,
SmemIteratorB,
CacheOpB,
MmaPolicy,
Stages
>;
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
// Define the epilogue
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape,
WarpMmaTensorOp,
kPartitionsK,
EpilogueOutputOp,
EpilogueOutputOp::kCount
>::Epilogue;
// Define the kernel
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
Mma,
Epilogue,
ThreadblockSwizzle,
conv::Operator::kFprop,
Conv2dProblemSize,
GroupMode::kSingleGroup
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Conv2dGroupFprop specialization for Optimized IteratorAlgorithm and
/// 2 stage pipeline that supports GroupMode::kSingleGroup.
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
typename MathOperatorTag,
conv::StrideSupport StrideSupport,
int AlignmentA,
int AlignmentB
>
struct DefaultConv2dGroupFprop <
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
2,
MathOperatorTag,
GroupMode::kSingleGroup,
IteratorAlgorithm::kOptimized,
StrideSupport,
AlignmentA,
AlignmentB
> {
static_assert(std::is_same<LayoutA, cutlass::layout::TensorNHWC>::value,
"Current group conv only supports NHWC layout");
static_assert(std::is_same<LayoutB, cutlass::layout::TensorNHWC>::value,
"Current group conv only supports NHWC layout");
static_assert(std::is_same<LayoutC, cutlass::layout::TensorNHWC>::value,
"Current group conv only supports NHWC layout");
// Define the core components from GEMM
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
2, MathOperatorTag>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
using IteratorA =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
LayoutA,
ThreadMapA,
AccessTypeA
>
>;
using SmemIteratorA = typename MmaCore::SmemIteratorA;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
using IteratorB =
cutlass::conv::threadblock::TileIterator<
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
LayoutB,
ThreadMapB,
AccessTypeB
>
>;
using SmemIteratorB = typename MmaCore::SmemIteratorB;
// Warp-level GEMM components
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
using MmaPolicy = typename MmaCore::MmaPolicy;
// Define the Mma
using Mma = threadblock::ImplicitGemmPipelined<
ThreadblockShape,
IteratorA,
SmemIteratorA,
IteratorB,
SmemIteratorB,
ElementC,
LayoutC,
MmaPolicy
>;
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
// Define the epilogue
using Epilogue = typename detail::DefaultConvEpilogue<
ArchTag,
ThreadblockShape,
WarpMmaTensorOp,
kPartitionsK,
EpilogueOutputOp
>::Epilogue;
// Define the kernel
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
Mma,
Epilogue,
ThreadblockSwizzle,
conv::Operator::kFprop,
Conv2dProblemSize,
GroupMode::kSingleGroup
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace conv
} // namespace cutlass
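A hedged composition sketch for the new optimized single-group path (every concrete template argument here is an assumption, and the trailing defaults for StrideSupport and the alignments are presumed to mirror DefaultConv2dFprop):

using GroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop<
    cutlass::half_t, cutlass::layout::TensorNHWC,  // A: activations
    cutlass::half_t, cutlass::layout::TensorNHWC,  // B: filters
    cutlass::half_t, cutlass::layout::TensorNHWC,  // C/D: output
    float,                                         // accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,        // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,          // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,           // instruction
    cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,                                             // stages
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::GroupMode::kSingleGroup,
    cutlass::conv::IteratorAlgorithm::kOptimized>::Kernel;

// Wrapped in the device-level handle renamed above:
using GroupFprop = cutlass::conv::device::ImplicitGemmConvolution<GroupFpropKernel>;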

View File

@ -39,14 +39,21 @@
#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_conv2d.h"
#include "cutlass/conv/kernel/direct_convolution.h"
#include "cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h"
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
#include "cutlass/conv/threadblock/depthwise_fprop_pipelined.h"
// Direct Conv Related Header files
#include "cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h"
#include "cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h"
#include "cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h"
#include "cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
@ -54,7 +61,7 @@ namespace conv {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
- /// Defines a kernel for Conv2dFprop
+ /// Defines a kernel for DepthwiseFprop
template <
typename ElementA,
typename LayoutA,
@ -80,12 +87,43 @@ template <
int AlignmentB = cutlass::sizeof_bits<ElementB>::value / cutlass::sizeof_bits<ElementB>::value
> struct DefaultDepthwiseFprop;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for DepthwiseFprop with direct convolution algorithm
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename OperatorClass,
typename ArchTag,
typename ThreadblockShape,
typename ThreadBlockOutputShape,
typename FilterShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
// MatrixShape<Height, Width>
typename StrideShape = cutlass::MatrixShape<-1, -1>,
// MatrixShape<Height, Width>
typename DilationShape = cutlass::MatrixShape<-1, -1>,
/// Access granularity of A matrix in units of elements
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
/// Access granularity of B matrix in units of elements
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
> struct DefaultDepthwiseDirect2dConvFprop;
/////////////////////////////////////////////////////////////////////////////////////////////////
// OpClassSimt convolutions
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Depthwise specialization for Analytic IteratorAlgorithm,
/// 2 stage pipeline, and FFMA-based mainloop for SM50
/// Defines a kernel for Depthwise specialization for Analytic IteratorAlgorithm
template <
typename ElementA,
typename LayoutA,
@ -210,6 +248,338 @@ struct DefaultDepthwiseFprop <
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Depthwise specialization for direct 2d conv implementation,
/// multiple stage pipeline, and SIMT-based mainloop
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename ArchTag,
typename ThreadblockShape,
typename ThreadBlockOutputShape,
typename FilterShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
conv::StrideSupport StrideSupport,
typename StrideShape,
typename DilationShape,
int AlignmentA,
int AlignmentB
>
struct DefaultDepthwiseDirect2dConvFprop <
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
arch::OpClassSimt,
ArchTag,
ThreadblockShape,
ThreadBlockOutputShape,
FilterShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
MathOperatorTag,
IteratorAlgorithm::kOptimized,
StrideSupport,
StrideShape,
DilationShape,
AlignmentA,
AlignmentB
> {
// One warp handles all of the groups assigned to the CTA.
static_assert(ThreadblockShape::kN == WarpShape::kN,
"ThreadblockShape::kN should be the same as WarpShape::kN");
static_assert(ThreadblockShape::kK == FilterShape::kCount && WarpShape::kK == FilterShape::kCount,
"ThreadblockShape::kK and WarpShape::kK should be the same as the filter size");
static_assert(ThreadblockShape::kM % WarpShape::kM == 0,
"ThreadblockShape::kM must be divisible by WarpShape::kM");
static_assert(ThreadBlockOutputShape::kN == 1, "ThreadBlockOutputShape::kN should be 1");
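// Instantiation sketch (illustrative shapes, not tuned): for a 3x3 filter and a
// 1x8x8x64 output tile per CTA, a choice consistent with the asserts above is
//
//   FilterShape            = cutlass::MatrixShape<3, 3>;                  // kCount = 9
//   ThreadblockShape       = cutlass::gemm::GemmShape<64, 64, 9>;         // kK == kCount
//   WarpShape              = cutlass::gemm::GemmShape<16, 64, 9>;         // kN matches CTA kN
//   ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, 64>; // 8x8 tile, 64 groups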
// Define the core components from GEMM
using MmaCore = typename cutlass::conv::threadblock::DepthwiseDirectConvMmaCoreWithLaneAccessSize<
ThreadblockShape,
ThreadBlockOutputShape,
FilterShape,
WarpShape,
InstructionShape,
ElementA,
layout::RowMajor,
ElementB,
layout::ColumnMajor,
ElementAccumulator,
layout::RowMajor,
arch::OpClassSimt,
128,
128,
Stages,
MathOperatorTag>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using IteratorA =
cutlass::conv::threadblock::DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kN>, // <output tile size per CTA, groups per CTA>
ThreadBlockOutputShape,
ElementA, LayoutA,
ThreadMapA
>;
using SmemIteratorA = typename MmaCore::SmemIteratorA;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
using IteratorB =
cutlass::conv::threadblock::DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kN, FilterShape::kCount>,
ElementB, LayoutB,
ThreadMapB
>;
using SmemIteratorB = typename MmaCore::SmemIteratorB;
// Warp-level GEMM components
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
using MmaPolicy = typename MmaCore::MmaPolicy;
using ThreadOutputShape = typename MmaCore::ThreadOutputShape;
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<ElementA>::value * AlignmentA) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the epilogue
using Epilogue = typename epilogue::threadblock::DefaultDirectConvEpilogueSimt<
ThreadblockShape, // <output tile size per CTA, groups per CTA>
WarpMmaSimtOp,
EpilogueOutputOp,
EpilogueOutputOp::kCount,
ThreadOutputShape,
ThreadBlockOutputShape
>::Epilogue;
// Define the Mma
using Mma = threadblock::DepthwiseFpropDirectConvMultipleStage<
ThreadblockShape,
IteratorA,
SmemIteratorA,
CacheOpA,
IteratorB,
SmemIteratorB,
CacheOpB,
MmaPolicy,
Stages,
Epilogue
>;
// Define the kernel
using Kernel = cutlass::conv::kernel::DirectConvolution<
Mma,
Epilogue,
ThreadblockSwizzle,
conv::Operator::kFprop,
Conv2dProblemSize,
cutlass::conv::GroupMode::kDepthwise,
ThreadBlockOutputShape
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a kernel for Depthwise specialization for direct 2d conv implementation,
/// multiple stage pipeline, and SIMT-based mainloop
template <
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementAccumulator,
typename ArchTag,
typename ThreadblockShape,
typename ThreadBlockOutputShape,
typename FilterShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOutputOp,
typename ThreadblockSwizzle,
int Stages,
typename MathOperatorTag,
conv::StrideSupport StrideSupport,
typename StrideShape,
typename DilationShape,
int AlignmentA,
int AlignmentB
>
struct DefaultDepthwiseDirect2dConvFprop <
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
arch::OpClassSimt,
ArchTag,
ThreadblockShape,
ThreadBlockOutputShape,
FilterShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
MathOperatorTag,
IteratorAlgorithm::kFixedStrideDilation,
StrideSupport,
StrideShape,
DilationShape,
AlignmentA,
AlignmentB
> {
// One warp handles all of the groups assigned to the CTA.
static_assert(ThreadblockShape::kN == WarpShape::kN,
"ThreadblockShape::kN should be the same as WarpShape::kN");
static_assert(ThreadblockShape::kK == FilterShape::kCount && WarpShape::kK == FilterShape::kCount,
"ThreadblockShape::kK and WarpShape::kK should be the same as the filter size");
static_assert(ThreadblockShape::kM % WarpShape::kM == 0,
"ThreadblockShape::kM must be divisible by WarpShape::kM");
static_assert(ThreadBlockOutputShape::kN == 1, "ThreadBlockOutputShape::kN should be 1");
static_assert(StrideShape::kRow >= 0 && StrideShape::kColumn >= 0, "Stride should be fixed");
static_assert(DilationShape::kRow >= 0 && DilationShape::kColumn >= 0, "Dilation should be fixed");
// Activations loaded by threadblock
static int const ActivationShapeH = (ThreadBlockOutputShape::kH - 1) * StrideShape::kRow +
(FilterShape::kRow - 1) * DilationShape::kRow + 1;
static int const ActivationShapeW = (ThreadBlockOutputShape::kW - 1) * StrideShape::kColumn +
(FilterShape::kColumn - 1) * DilationShape::kColumn + 1;
using ActivationShape =
cutlass::conv::TensorNHWCShape<1, ActivationShapeH, ActivationShapeW, ThreadblockShape::kN >;
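// Worked example (hypothetical tile): an 8x8 output tile with a 3x3 filter,
// stride 2, and dilation 1 loads
//   ActivationShapeH = (8 - 1) * 2 + (3 - 1) * 1 + 1 = 17
//   ActivationShapeW = (8 - 1) * 2 + (3 - 1) * 1 + 1 = 17
// activations, i.e. ActivationShape = TensorNHWCShape<1, 17, 17, ThreadblockShape::kN>.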
// Define the core components from GEMM
using MmaCore = typename cutlass::conv::threadblock::DepthwiseDirectConvMmaCoreWithLaneAccessSize<
ThreadblockShape,
ThreadBlockOutputShape,
FilterShape,
WarpShape,
InstructionShape,
ElementA,
layout::RowMajor,
ElementB,
layout::ColumnMajor,
ElementAccumulator,
layout::RowMajor,
arch::OpClassSimt,
128,
128,
Stages,
MathOperatorTag,
IteratorAlgorithm::kFixedStrideDilation,
StrideShape,
DilationShape,
ActivationShape>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using IteratorA =
cutlass::conv::threadblock::DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kN>, // <output tile size per CTA, groups per CTA>
ThreadBlockOutputShape,
StrideShape,
DilationShape,
ActivationShape,
ElementA, LayoutA,
ThreadMapA
>;
using SmemIteratorA = typename MmaCore::SmemIteratorA;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
using IteratorB =
cutlass::conv::threadblock::DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized<
cutlass::MatrixShape<ThreadblockShape::kN, FilterShape::kCount>,
ElementB, LayoutB,
ThreadMapB
>;
using SmemIteratorB = typename MmaCore::SmemIteratorB;
// Warp-level GEMM components
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
using MmaPolicy = typename MmaCore::MmaPolicy;
using ThreadOutputShape = typename MmaCore::ThreadOutputShape;
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<ElementA>::value * AlignmentA) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the epilogue
using Epilogue = typename epilogue::threadblock::DefaultDirectConvEpilogueSimt<
ThreadblockShape, // <output tile size per CTA, groups per CTA>
WarpMmaSimtOp,
EpilogueOutputOp,
EpilogueOutputOp::kCount,
ThreadOutputShape,
ThreadBlockOutputShape
>::Epilogue;
// Define the Mma
using Mma = threadblock::DepthwiseFpropDirectConvMultipleStage<
ThreadblockShape,
IteratorA,
SmemIteratorA,
CacheOpA,
IteratorB,
SmemIteratorB,
CacheOpB,
MmaPolicy,
Stages,
Epilogue,
IteratorAlgorithm::kFixedStrideDilation
>;
// Define the kernel
using Kernel = cutlass::conv::kernel::DirectConvolution<
Mma,
Epilogue,
ThreadblockSwizzle,
conv::Operator::kFprop,
Conv2dProblemSize,
cutlass::conv::GroupMode::kDepthwise,
ThreadBlockOutputShape
>;
};
} // namespace kernel
} // namespace conv

View File

@ -0,0 +1,505 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a multi-staged Depthwise Convolution kernel.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/semaphore.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv2d_problem_size.h"
#include "cutlass/conv/conv3d_problem_size.h"
#include "cutlass/epilogue/threadblock/output_iterator_parameter.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace conv {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Parameters structure
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad)
typename Arguments_, ///! Kernel Arguments
typename ConvOutputIteratorParameter_, ///! Output Iterator Params
typename ConvProblemSize_ = Conv2dProblemSize, ///! Convolutional operator on 2D or 3D problem
conv::GroupMode GroupMode_ = conv::GroupMode::kNone, ///! Group mode
typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1> > ///! OutputShape per ThreadBlock
struct DirectConvolutionParams {
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using ThreadBlockOutputShape = ThreadBlockOutputShape_;
static Operator const kConvolutionalOperator = ConvOperator;
using ConvProblemSize = ConvProblemSize_;
using Arguments = Arguments_;
using ConvOutputIteratorParameter = ConvOutputIteratorParameter_;
using ThreadblockShape = typename Mma::Shape;
static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm;
static conv::GroupMode const kGroupMode = GroupMode_;
static int const kStages = Mma::kStages;
ConvProblemSize problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
gemm::GemmCoord implicit_gemm_problem_size;
int swizzle_log_tile;
int smem_size_;
int gemm_k_iterations;
int gemm_k_iterations_per_channel;
typename Mma::IteratorA::Params iterator_A;
typename Mma::IteratorA::Element const *ptr_A;
typename Mma::IteratorB::Params iterator_B;
typename Mma::IteratorB::Element const *ptr_B;
typename Mma::IteratorB::Element *ptr_reordered_B;
typename Epilogue::OutputTileIterator::Params iterator_C;
typename Epilogue::OutputTileIterator::Element *ptr_C;
typename Epilogue::OutputTileIterator::Params iterator_D;
typename Epilogue::OutputTileIterator::Element *ptr_D;
typename EpilogueOutputOp::Params output_op;
int *semaphore;
SplitKMode split_k_mode;
int split_k_slices;
//
// Methods
//
CUTLASS_HOST_DEVICE
DirectConvolutionParams() : swizzle_log_tile(0), gemm_k_iterations(0) {}
///
CUTLASS_HOST_DEVICE
DirectConvolutionParams(Arguments const &args, int *semaphore = nullptr)
: problem_size(args.problem_size),
implicit_gemm_problem_size(
cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)),
iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())),
ptr_A(args.ref_A.data()),
iterator_B(Mma::IteratorB::getParams(args.problem_size, args.ref_B.layout())),
ptr_B(args.ref_B.data()),
ptr_reordered_B(args.ref_reordered_B.data()),
iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), args.problem_size),
ptr_C(args.ref_C.data()),
iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), args.problem_size),
ptr_D(args.ref_D.data()),
output_op(args.output_op),
semaphore(semaphore),
split_k_mode(args.split_k_mode),
split_k_slices(args.problem_size.split_k_slices) {
gemm_k_iterations =
depthwise_gemm_k_iterations<ThreadBlockOutputShape::kN,
ThreadBlockOutputShape::kH,
ThreadBlockOutputShape::kW>(kConvolutionalOperator,
ThreadblockShape::kK,
args.problem_size,
kIteratorAlgorithm,
kGroupMode,
ThreadblockShape::kN);
gemm_k_iterations_per_channel = implicit_gemm_k_iterations_per_channel(
kConvolutionalOperator, ThreadblockShape::kK, args.problem_size, kIteratorAlgorithm);
ThreadblockSwizzle threadblock_swizzle;
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
kConvolutionalOperator,
problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
args.problem_size.split_k_slices);
swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
// Dynamic SMEM usage because stride and dilation are runtime params.
smem_size_ = (iterator_A.activation_size * kStages + iterator_B.filter_size);
}
CUTLASS_HOST_DEVICE
int get_smem_size() {
// Dynamic Smem Size
return smem_size_;
}
};
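// Launch-side sketch (hypothetical host code, not part of this file): because
// smem_size_ is only known at runtime, a launcher has to opt in to large
// dynamic shared memory before the first launch, e.g.:
//
//   int smem_size = params.get_smem_size();
//   if (smem_size >= (48 << 10)) {
//     cudaFuncSetAttribute(kernel_entry,                      // hypothetical kernel symbol
//                          cudaFuncAttributeMaxDynamicSharedMemorySize,
//                          smem_size);
//   }
//   kernel_entry<<<grid, block, smem_size, stream>>>(params);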
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Params_, typename ElementB_>
struct ReorderKernel {
using Params = Params_;
using ElementB = ElementB_;
union SharedStorage {};
static unsigned int const kReorderKernelThreadPerCTA = 128;
CUTLASS_HOST_DEVICE
ReorderKernel() {}
CUTLASS_HOST_DEVICE
static dim3 get_grid_shape(Params const &params) {
return dim3{static_cast<unsigned int>(
(params.problem_size.filter_size() + kReorderKernelThreadPerCTA - 1) /
kReorderKernelThreadPerCTA),
1,
1};
}
CUTLASS_HOST_DEVICE
static dim3 get_block_shape() { return dim3{kReorderKernelThreadPerCTA, 1, 1}; }
CUTLASS_HOST_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
int64_t m = static_cast<int64_t>(params.problem_size.groups);
int64_t n = static_cast<int64_t>(params.problem_size.filter_size() / params.problem_size.K);
const ElementB *src_with_type = static_cast<const ElementB *>(params.ptr_B);
ElementB *dst_with_type = static_cast<ElementB *>(params.ptr_reordered_B);
int64_t linear_index = blockIdx.x * kReorderKernelThreadPerCTA + threadIdx.x;
int64_t index_m = linear_index / n;
int64_t index_n = linear_index % n;
int64_t new_linear_index = index_m + index_n * m;
if (linear_index < m * n) {
dst_with_type[new_linear_index] = src_with_type[linear_index];
}
return;
}
};
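// Reorder example (hypothetical sizes): with m = groups = 4 and
// n = filter_size() / K = 9, the element src[g * 9 + rs] moves to
// dst[g + rs * 4], i.e. the kernel writes a transposed copy of B in which
// consecutive groups are contiguous, matching the direct-conv filter
// iterator's access pattern.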
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad)
typename ConvProblemSize_ = Conv2dProblemSize, ///! Convolutional operator on 2D or 3D problem
conv::GroupMode GroupMode_ = conv::GroupMode::kNone, ///! Group mode
typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1>
>
struct DirectConvolution {
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using ThreadBlockOutputShape = ThreadBlockOutputShape_;
static Operator const kConvolutionalOperator = ConvOperator;
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename EpilogueOutputOp::ElementOutput;
/// Set output tensor C layout
using LayoutC = LayoutA;
using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator;
using ElementCompute = typename EpilogueOutputOp::ElementCompute;
using WarpMmaOperator = typename Mma::Policy::Operator;
using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
using MathOperator = typename ArchMmaOperator::Operator;
using OperatorClass = typename WarpMmaOperator::OperatorClass;
using ArchTag = typename WarpMmaOperator::ArchTag;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename WarpMmaOperator::Shape;
using InstructionShape = typename cutlass::gemm::GemmShape<1, 1, 1>;
static int const kStages = Mma::kStages;
static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm;
static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
using TensorRefA = typename Mma::IteratorA::TensorRef;
using TensorRefB = typename Mma::IteratorB::TensorRef;
using TensorRefC = cutlass::TensorRef<ElementC, LayoutC>;
/// Check iterator A and B convolution dimension are the same and
// set device::ImplicitGemmConvolution::kConvDim
static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim,
"Convolution on different dimensions is not supported");
static int const kConvDim = Mma::IteratorA::kConvDim;
/// Conv dimension and problem size structure (Conv2d or Conv3d)
using ConvProblemSize = ConvProblemSize_;
static conv::GroupMode const kGroupMode = GroupMode_;
//
//
//
using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter<
LayoutC,
typename Epilogue::OutputTileIterator::Layout,
TensorRefC,
ConvOperator,
ConvProblemSize
>;
/// Argument structure
struct Arguments {
//
// Data members
//
ConvProblemSize problem_size;
TensorRefA ref_A;
TensorRefB ref_B;
TensorRefB ref_reordered_B;
TensorRefC ref_C;
TensorRefC ref_D;
typename EpilogueOutputOp::Params output_op;
SplitKMode split_k_mode;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments() { }
CUTLASS_HOST_DEVICE
Arguments(
ConvProblemSize const & problem_size
):
problem_size(problem_size) { }
CUTLASS_HOST_DEVICE
Arguments(
ConvProblemSize const & problem_size,
TensorRefA const & ref_A,
TensorRefB const & ref_B,
TensorRefC const & ref_C,
TensorRefC const & ref_D,
typename EpilogueOutputOp::Params const & output_op,
TensorRefB const & ref_reordered_B = nullptr,
SplitKMode const & split_k_mode = SplitKMode::kSerial
):
problem_size(problem_size),
ref_A(ref_A),
ref_B(ref_B),
ref_C(ref_C),
ref_D(ref_D),
output_op(output_op),
ref_reordered_B(ref_reordered_B),
split_k_mode(split_k_mode)
{
}
};
using Params =
typename cutlass::conv::kernel::DirectConvolutionParams<Mma,
Epilogue,
ThreadblockSwizzle,
kConvolutionalOperator,
Arguments,
ConvOutputIteratorParameter,
ConvProblemSize,
kGroupMode,
ThreadBlockOutputShape>;
using ReorderKernel = typename cutlass::conv::kernel::ReorderKernel<Params, ElementB>;
/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
//
// Methods
//
CUTLASS_HOST_DEVICE
DirectConvolution() { }
/// Executes one ImplicitGEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_idx =
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if threadblock is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() ||
params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) {
return;
}
// Compute position within threadblock
int thread_idx = threadIdx.x;
int iterator_column_offset = 0;
int filter_row_offset = 0;
if (kGroupMode != GroupMode::kNone) {
if (kGroupMode == GroupMode::kDepthwise) {
iterator_column_offset += threadblock_tile_idx.n() * Mma::Shape::kN;
}
}
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
params.iterator_A,
params.problem_size,
params.ptr_A,
thread_idx,
MatrixCoord(
threadblock_tile_idx.m() + threadblock_tile_idx.k(),
iterator_column_offset
)
);
typename Mma::IteratorB iterator_B(
params.iterator_B,
params.problem_size,
params.ptr_reordered_B,
thread_idx,
MatrixCoord(
filter_row_offset,
iterator_column_offset
)
);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
// Compute logical position within grid
threadblock_tile_idx =
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
MatrixCoord threadblock_offset(
threadblock_tile_idx.m() + threadblock_tile_idx.k(),
threadblock_tile_idx.n() * Mma::Shape::kN
);
// Tile iterator writing to destination tensor
typename Epilogue::OutputTileIterator iterator_D(
params.iterator_D,
params.ptr_D,
ConvOutputIteratorParameter::extent(params.problem_size),
thread_idx,
threadblock_offset
);
// Tile iterator reading from source accumulator tensor
typename Epilogue::OutputTileIterator iterator_C(
params.iterator_C,
params.ptr_C,
ConvOutputIteratorParameter::extent(params.problem_size),
thread_idx,
threadblock_offset
);
// Construct the epilogue
Epilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Compute threadblock-scoped matrix multiply-add
// Epilogue is fused in the mainloop
mma(params.gemm_k_iterations,
accumulators,
iterator_A,
params.iterator_A,
iterator_B,
params.iterator_B,
accumulators,
epilogue,
output_op,
iterator_D,
iterator_C,
params.split_k_slices);
}
};
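// End-to-end sketch (hypothetical host code): the filter must be reordered once
// before the fused mainloop+epilogue kernel runs; both kernels consume the same
// Params object. A device-level wrapper would do roughly:
//
//   typename DirectConv::Params params(args);                 // DirectConv: this kernel type
//   cutlass::Kernel<typename DirectConv::ReorderKernel>
//       <<<DirectConv::ReorderKernel::get_grid_shape(params),
//          DirectConv::ReorderKernel::get_block_shape()>>>(params);
//   cutlass::Kernel<DirectConv>
//       <<<grid, DirectConv::kThreadCount, params.get_smem_size()>>>(params);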
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace conv
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,325 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates exposing architecture support for depthwise convolution
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/thread/mma.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace conv {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// MMA operation
template <
/// Size of the matrix product (concept: GemmShape)
typename Shape_,
/// Number of threads participating
int kThreads_,
/// Data type of A elements
typename ElementA,
/// Data type of B elements
typename ElementB,
/// Element type of C matrix
typename ElementC,
/// Inner product operator
typename Operator
>
struct ElementwiseInnerProduct;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// General implementation
template <
/// Size of the matrix product (concept: GemmShape)
typename Shape_,
/// Data type of A elements
typename ElementA_,
/// Data type of B elements
typename ElementB_,
/// Element type of C matrix
typename ElementC_>
struct ElementwiseInnerProduct<Shape_, 1, ElementA_, ElementB_, ElementC_, arch::OpMultiplyAdd> {
using Shape = Shape_;
using Operator = arch::OpMultiplyAdd;
using ElementC = ElementC_;
CUTLASS_HOST_DEVICE
void operator()(Array<ElementC_, Shape::kN> &d,
Array<ElementA_, Shape::kN> const &a,
Array<ElementB_, Shape::kN> const &b,
Array<ElementC_, Shape::kN> const &c) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Shape::kN; ++i) {
d[i] = a[i] * b[i] + c[i];
}
}
};
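// Usage sketch (illustrative values): a 4-wide elementwise FMA; each lane
// computes d[i] = a[i] * b[i] + c[i].
//
//   using Op = ElementwiseInnerProduct<cutlass::gemm::GemmShape<1, 4, 1>, 1,
//                                      float, float, float,
//                                      cutlass::arch::OpMultiplyAdd>;
//   cutlass::Array<float, 4> a, b, c, d;
//   a.fill(2.0f); b.fill(3.0f); c.fill(1.0f);
//   Op()(d, a, b, c);   // d[i] == 7.0f for every i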
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Specialization of half_t
template <>
struct ElementwiseInnerProduct<
gemm::GemmShape<2, 2, 1>,
1,
half_t,
half_t,
half_t,
arch::OpMultiplyAdd> {
using Shape = gemm::GemmShape<2, 2, 1>;
using Operator = arch::OpMultiplyAdd;
using ElementC = half_t;
CUTLASS_HOST_DEVICE
void operator()(
Array<half_t, 2> &d,
Array<half_t, 2> const &a,
Array<half_t, 2> const &b,
Array<half_t, 2> const &c
) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))
__half2 const & A = reinterpret_cast<__half2 const &>(a);
__half2 const & B = reinterpret_cast<__half2 const &>(b);
__half2 const & C = reinterpret_cast<__half2 const &>(c);
__half2 tmp_D = __hfma2(A, B, C);
d = reinterpret_cast<Array<half_t, 2> const &>(tmp_D);
#else
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 2; ++i) {
d[i] = a[i] * b[i] + c[i];
}
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape,
/// Data type of A elements
typename ElementA,
/// Data type of B elements
typename ElementB,
/// Element type of C matrix
typename ElementC,
/// Concept: arch::OpMultiplyAdd or arch::Mma<>
typename Operator = arch::OpMultiplyAdd,
/// Used for partial specialization
typename Enable = bool
>
struct DepthwiseDirectConvElementwiseInnerProduct;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Template that handles all packed matrix layouts
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Data type of A elements
typename ElementA_,
/// Data type of B elements
typename ElementB_,
/// Element type of C matrix
typename ElementC_,
/// Operator used to compute GEMM
typename Operator_
>
struct DepthwiseDirectConvElementwiseInnerProductGeneric {
/// Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
/// Data type of operand A
using ElementA = ElementA_;
/// Data type of operand B
using ElementB = ElementB_;
/// Element type of operand C
using ElementC = ElementC_;
/// Underlying mathematical operator
using Operator = Operator_;
/// A operand storage
using FragmentA = Array<ElementA, Shape::kMN>;
/// B operand storage
using FragmentB = Array<ElementB, Shape::kN>;
/// C operand storage
using FragmentC = Array<ElementC, Shape::kMN>;
/// Instruction
using MmaOp = cutlass::conv::thread::ElementwiseInnerProduct<
gemm::GemmShape<Shape::kN, Shape::kN, 1>,
1,
ElementA,
ElementB,
ElementC,
Operator>;
//
// Methods
//
/// Computes a matrix product D = A * B + C
CUTLASS_HOST_DEVICE
void operator()(
FragmentC & D,
FragmentA const & A,
FragmentB const & B,
FragmentC const & C) {
Array<ElementC, Shape::kN> *ptr_D = reinterpret_cast<Array<ElementC, Shape::kN> *>(&D);
Array<ElementA, Shape::kN> const *ptr_A =
reinterpret_cast<Array<ElementA, Shape::kN> const *>(&A);
Array<ElementB, Shape::kN> const *ptr_B =
reinterpret_cast<Array<ElementB, Shape::kN> const *>(&B);
MmaOp mma_op;
// Copy accumulators
D = C;
// Compute matrix product
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Shape::kN / MmaOp::Shape::kN; ++n) {
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < Shape::kM; ++m) {
Array<ElementC, MmaOp::Shape::kN> tmpD = ptr_D[m * Shape::kN / MmaOp::Shape::kN + n];
Array<ElementA, MmaOp::Shape::kN> tmpA = ptr_A[m * Shape::kN / MmaOp::Shape::kN + n];
Array<ElementB, MmaOp::Shape::kN> tmpB = ptr_B[n];
mma_op(tmpD, tmpA, tmpB, tmpD);
ptr_D[m * Shape::kN / MmaOp::Shape::kN + n] = tmpD;
}
}
}
};
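// Tile-level sketch (hypothetical 2x4 tile): FragmentA and FragmentC hold
// Shape::kMN = 8 elements, FragmentB holds the kN = 4 per-group values, and the
// nested loops apply d[m][n] = a[m][n] * b[n] + c[m][n].
//
//   using Product = DepthwiseDirectConvElementwiseInnerProductGeneric<
//       cutlass::gemm::GemmShape<2, 4, 1>, float, float, float,
//       cutlass::arch::OpMultiplyAdd>;
//   Product::FragmentA a;  Product::FragmentB b;  Product::FragmentC c, d;
//   a.fill(2.0f); b.fill(3.0f); c.fill(1.0f);
//   Product()(d, a, b, c);   // every element of d becomes 7.0f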
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Data type of A elements
typename ElementA_,
/// Data type of B elements
typename ElementB_,
/// Element type of C matrix
typename ElementC_
>
struct DepthwiseDirectConvElementwiseInnerProduct<
Shape_,
ElementA_,
ElementB_,
ElementC_,
arch::OpMultiplyAdd
> {
/// Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
/// Data type of operand A
using ElementA = ElementA_;
/// Data type of operand B
using ElementB = ElementB_;
/// Element type of operand C
using ElementC = ElementC_;
/// Underlying mathematical operator
using Operator = arch::OpMultiplyAdd;
/// A operand storage
using FragmentA =
Array<ElementA, Shape::kMN>; // output_tile_size per thread * groups_per_thread
/// B operand storage
using FragmentB = Array<ElementB, Shape::kN>; // 1 * groups_per_thread
/// C operand storage
using FragmentC =
Array<ElementC, Shape::kMN>; // output_tile_size per thread * groups_per_thread
static bool const use_optimized = false;
using ArchMmaOperator = DepthwiseDirectConvElementwiseInnerProductGeneric<Shape,
ElementA,
ElementB,
ElementC,
Operator>;
//
// Methods
//
/// Computes a matrix product D = A * B + C
CUTLASS_HOST_DEVICE
void operator()(
FragmentC & D,
FragmentA const & A,
FragmentB const & B,
FragmentC const & C) {
ArchMmaOperator mma;
mma(D, A, B, C);
}
};
} // namespace thread
} // namespace conv
} // namespace cutlass

View File

@ -145,6 +145,7 @@ private:
uint32_t predicates_[kAccessesPerVector];
int filter_rs_;
int filter_c_;
int channels_per_group_;
//
// Assertions
@ -175,6 +176,7 @@ public:
filter_c_ = threadblock_offset.row() + thread_coord.contiguous();
Index column = threadblock_offset.column() + thread_coord.strided();
channels_per_group_ = problem_size_.C / problem_size_.groups;
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
@ -188,7 +190,7 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C);
clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= channels_per_group_);
}
pointer_ += (
@ -229,7 +231,7 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C);
clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= channels_per_group_);
}
pointer_ += next;

View File

@ -0,0 +1,230 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Extracts the host-params objects into non-template code.
*/
#pragma once
#define TRACE_CONV_PARAMS_INITIALIZERS_ENABLED 0
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv2d_problem_size.h"
#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED
#include <fstream>
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace conv {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized
template<typename Layout_ = layout::TensorNHWC >
struct Depthwise2dFpropDirectConvParams;
/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation
template<typename Layout_ = layout::TensorNHWC >
struct Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams;
/// Parameters structure used for DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized
template<typename Layout_ = layout::TensorNHWC >
struct Depthwise2dFpropDirectConvFilterIteratorParams;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized
template<>
struct Depthwise2dFpropDirectConvParams<layout::TensorNHWC> {
using Layout = layout::TensorNHWC;
Layout layout;
int32_t activation_tile_h;
int32_t activation_tile_w;
int32_t activation_tile_hw;
FastDivmod activation_tile_w_divmod;
int filter[2];
int stride[2];
int dilation[2];
int inc_next[2];
FastDivmod pq_divmod;
FastDivmod q_divmod;
int activation_load_count;
int activation_storage_elements;
int activation_size;
//
// Methods
//
CUTLASS_HOST_DEVICE
Depthwise2dFpropDirectConvParams() { }
CUTLASS_HOST_DEVICE
Depthwise2dFpropDirectConvParams(
Conv2dProblemSize const &problem_size,
Layout const &layout, ///< layout object
MatrixCoord threadblock_shape, ///< CTA threadblock Shape
Layout::TensorCoord threadblock_output_shape, ///< Output tile Shape per threadblock
const int element_size_bits, ///< bits of activation element
const int thread_count, ///< threads per threadblock
const int thread_count_contiguous, ///< number of threads along the contiguous dimension
const int element_per_load) ///< elements per load
: layout(layout) {
filter[0] = problem_size.S;
filter[1] = problem_size.R;
stride[0] = problem_size.stride_w;
stride[1] = problem_size.stride_h;
dilation[0] = problem_size.dilation_w;
dilation[1] = problem_size.dilation_h;
// Compute activation_tile size per threadblock because stride and dilation are runtime params.
activation_tile_h = (threadblock_output_shape.h() - 1) * problem_size.stride_h +
(problem_size.R - 1) * problem_size.dilation_h + 1;
activation_tile_w = (threadblock_output_shape.w() - 1) * problem_size.stride_w +
(problem_size.S - 1) * problem_size.dilation_w + 1;
activation_tile_hw = activation_tile_h * activation_tile_w;
activation_tile_w_divmod = FastDivmod(activation_tile_w);
/// The two values below cannot be determined at compile time because stride and dilation are runtime params
activation_load_count = (thread_count_contiguous * activation_tile_hw + (thread_count - 1)) / thread_count;
activation_storage_elements = activation_load_count * element_per_load * thread_count;
activation_size = activation_storage_elements * element_size_bits / 8;
// Fastdivmod for output P, Q
int tiles_p =
(problem_size.P + (threadblock_output_shape.h() - 1)) / (threadblock_output_shape.h());
int tiles_q = (problem_size.Q + (threadblock_output_shape.w() - 1)) /
(threadblock_output_shape.w());
pq_divmod = FastDivmod(tiles_p * tiles_q);
q_divmod = FastDivmod(tiles_q);
// next S
inc_next[0] = problem_size.dilation_w;
// next R
inc_next[1] = (activation_tile_w * problem_size.dilation_h - (problem_size.S - 1) * problem_size.dilation_w);
}
};
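// Worked example (hypothetical sizes): P = Q = 28 with an 8x8 output tile per
// threadblock gives tiles_p = tiles_q = ceil(28 / 8) = 4. For a linear tile
// index t, pq_divmod then yields n = t / 16 with residual r, and q_divmod
// splits r into p_tile = r / 4 and q_tile = r % 4.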
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation
template <>
struct Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams<layout::TensorNHWC> {
using Layout = layout::TensorNHWC;
Layout layout;
FastDivmod pq_divmod;
FastDivmod q_divmod;
int activation_size;
//
// Methods
//
CUTLASS_HOST_DEVICE
Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams() {}
CUTLASS_HOST_DEVICE
Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams(
Conv2dProblemSize const &problem_size,
Layout const &layout, ///< Layout object
MatrixCoord threadblock_shape, ///< Threadblock Shape
Layout::TensorCoord threadblock_output_shape, ///< Output tile Shape per threadblock
const int activation_size_ ///< Activation size loaded by iterator
)
: layout(layout),
activation_size(activation_size_) {
// Fastdivmod for output P, Q
int tiles_p =
(problem_size.P + (threadblock_output_shape.h() - 1)) / (threadblock_output_shape.h());
int tiles_q =
(problem_size.Q + (threadblock_output_shape.w() - 1)) / (threadblock_output_shape.w());
pq_divmod = FastDivmod(tiles_p * tiles_q);
q_divmod = FastDivmod(tiles_q);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Parameters structure used for DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized
template <>
struct Depthwise2dFpropDirectConvFilterIteratorParams<layout::TensorNHWC> {
using Layout = layout::TensorNHWC;
Layout layout;
int filter_size;
bool is_convolution;
//
// Methods
//
CUTLASS_HOST_DEVICE
Depthwise2dFpropDirectConvFilterIteratorParams() {}
CUTLASS_HOST_DEVICE
Depthwise2dFpropDirectConvFilterIteratorParams(
Conv2dProblemSize const &problem_size,
Layout const &layout, ///< Layout object
MatrixCoord threadblock_shape, ///< Threadblock Shape
const int filter_size_) ///< Filter size loaded by iterator
: layout(layout),
filter_size(filter_size_),
is_convolution(problem_size.mode == Mode::kConvolution){}
};
} // namespace threadblock
} // namespace conv
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,314 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile)
matrix from memory.
This iterator assumes TensorNHWC layout of tensors in Global Memory.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/conv/conv2d_problem_size.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/threadblock/depthwise_direct_conv_params.h"
#include "cutlass/coord.h"
#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace conv {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Shape_,
typename OutputTileShape_,
typename StrideShape_,
typename DilationShape_,
typename ActivationShape_,
typename Element_,
typename Layout_,
typename ThreadMap_,
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess> >
class DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation {
public:
//
// Types
//
using Shape = Shape_;
using OutputTileShape = OutputTileShape_;
using Element = Element_;
using Layout = Layout_;
using TensorCoord = typename Layout::TensorCoord;
using ThreadMap = ThreadMap_;
using AccessType = AccessType_;
using TensorRef = cutlass::TensorRef<Element, Layout>;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kFixedStrideDilation;
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 2;
using ConvProblemSize = typename conv::Conv2dProblemSize;
// Compile-time values of the stride, dilation, and activation shape
using StrideShape = StrideShape_;
using DilationShape = DilationShape_;
using ActivationShape = ActivationShape_;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
static int const kActivationSize = ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess * ThreadMap::kThreads *
sizeof_bits<Element>::value / 8;
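// e.g. 4 iterations x 8 half_t elements per access x 128 threads x 2 B/element
// = 8 KiB of activations staged per threadblock (illustrative numbers).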
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
"Vectors implied by the thread map must be divisible by the access type.");
//
// Simplifying assertions
//
static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1");
static_assert(OutputTileShape::kN == 1, "Require OutputTileShape::kN == 1");
static_assert(OutputTileShape::kC == Shape::kColumn, "Require OutputTile shape == channels per threadblock");
//
// Parameters structure
//
using Params = Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams<Layout>;
private:
Conv2dProblemSize const &problem_size_;
Params const &params_;
char const *pointer_;
// Base channels for current threadblock
int base_c_;
// Base activation index for current threadblock
int offset_intial_npq_;
// Base activation coord for current threadblock
TensorCoord activatioin_base_;
// Initial thread position
int offset_initial_hwc_;
// Current load index within this thread's per-tile iterations
int iterator_load_;
// Current linearized (h, w, c) loading position of this thread
int iterator_hwc_;
// Whether the activation batch index N lies inside the tensor
bool valid_n_;
public:
CUTLASS_HOST_DEVICE
DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation(
Params const &params,
Conv2dProblemSize const &problem_size,
Element const *ptr,
int thread_idx,
MatrixCoord const &threadblock_offset =
MatrixCoord()
)
: params_(params),
problem_size_(problem_size),
pointer_(reinterpret_cast<char const *>(ptr)),
offset_intial_npq_(threadblock_offset.row()),
offset_initial_hwc_(thread_idx),
iterator_load_(0) {
base_c_ = threadblock_offset.column();
set_iteration_index(0);
set_activation_coord(offset_intial_npq_);
}
CUTLASS_HOST_DEVICE
void set_activation_coord(int offset_npq) {
int offset_inital_n, offset_inital_p, offset_inital_q;
int residual;
params_.pq_divmod(offset_inital_n, residual, offset_npq);
params_.q_divmod(offset_inital_p, offset_inital_q, residual);
int base_n = offset_inital_n;
int base_h =
offset_inital_p * OutputTileShape::kH * StrideShape::kRow - problem_size_.pad_h;
int base_w =
offset_inital_q * OutputTileShape::kW * StrideShape::kColumn - problem_size_.pad_w;
activatioin_base_ = TensorCoord(base_n, base_h, base_w, base_c_);
valid_n_ = activatioin_base_.n() < problem_size_.N;
}
CUTLASS_HOST_DEVICE
static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) {
return Params(
problem_size,
layout,
{Shape::kRow, Shape::kColumn},
{OutputTileShape::kN, OutputTileShape::kH, OutputTileShape::kW, OutputTileShape::kC},
kActivationSize);
}
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
iterator_hwc_ = offset_initial_hwc_ + index * ThreadMap::kThreads;
iterator_load_ = index;
}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
}
CUTLASS_HOST_DEVICE
void advance() {
// Go to next threadblock
offset_intial_npq_ += problem_size_.split_k_slices;
set_iteration_index(0);
set_activation_coord(offset_intial_npq_);
}
/// Returns the coordinate in the activations tensor X that is currently pointed to
/// by the iterator.
CUTLASS_HOST_DEVICE
TensorCoord at() const {
int c = iterator_hwc_ % ThreadMap::Detail::ShapeVec::kContiguous;
int next = iterator_hwc_ / ThreadMap::Detail::ShapeVec::kContiguous;
int h = next / ActivationShape::kW;
int w = next % ActivationShape::kW;
c = c * AccessType::kElements;
return activatioin_base_ + TensorCoord(0, h, w, c);
}
/// Returns true if the current coordinate is within the activations tensor X
CUTLASS_HOST_DEVICE
bool valid() const {
TensorCoord coord = at();
bool valid_c = coord.c() < problem_size_.C;
bool valid_h = coord.h() >= 0 && coord.h() < problem_size_.H;
bool valid_w = coord.w() >= 0 && coord.w() < problem_size_.W;
return valid_n_ && valid_c && valid_h && valid_w;
}
/// Returns a pointer to the vector starting at the current coordinate
CUTLASS_HOST_DEVICE
AccessType const *get() const {
TensorCoord coord = at();
LongIndex offset = params_.layout(coord);
AccessType const *ptr =
reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
return ptr;
}
/// Increments to the next memory access
CUTLASS_HOST_DEVICE
DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation &operator++() {
++iterator_load_;
iterator_hwc_ += ThreadMap::kThreads;
if (iterator_load_ < ThreadMap::Iterations::kCount) {
return *this;
}
iterator_load_ = 0;
iterator_hwc_ = offset_initial_hwc_;
return *this;
}
/// Determines the activation size loaded by iterator
CUTLASS_HOST_DEVICE
int get_load_size() {
return kActivationSize;
}
/// Determines the iterations needed
CUTLASS_HOST_DEVICE
int get_iteration_num() {
return ThreadMap::Iterations::kCount;
}
/// Determines whether the Depthwise fprop can execute the given problem.
CUTLASS_HOST_DEVICE
static Status can_implement(Conv2dProblemSize const &problem_size) {
// check stride and dilation constraint
if (problem_size.stride_h != StrideShape::kRow || problem_size.stride_w != StrideShape::kColumn) {
return Status::kErrorInvalidProblem;
}
if (problem_size.dilation_h != DilationShape::kRow || problem_size.dilation_w != DilationShape::kColumn) {
return Status::kErrorInvalidProblem;
}
// check alignment constraint on iterator's contiguous dimension
if (problem_size.C % AccessType::kElements) {
return Status::kErrorInvalidProblem;
}
return Status::kSuccess;
}
};
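// Consumption sketch: a mainloop using this iterator guards every vector load
// with valid(), so padded or out-of-range taps are zero-filled instead of read:
//
//   typename Iterator::AccessType frag;   // Iterator: this class (hypothetical alias)
//   frag.clear();
//   if (iterator.valid()) {
//     frag = *iterator.get();
//   }
//   ++iterator;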
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace conv
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,291 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile)
matrix from memory.
This iterator assumes TensorNHWC layout of tensors in Global Memory.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/conv/conv2d_problem_size.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/threadblock/depthwise_direct_conv_params.h"
#include "cutlass/coord.h"
#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace conv {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Shape_,
typename OutputTileShape_,
typename Element_,
typename Layout_,
typename ThreadMap_,
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess> >
class DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized {
public:
//
// Types
//
using Shape = Shape_;
using OutputTileShape = OutputTileShape_;
using Element = Element_;
using Layout = Layout_;
using TensorCoord = typename Layout::TensorCoord;
using ThreadMap = ThreadMap_;
using AccessType = AccessType_;
using TensorRef = cutlass::TensorRef<Element, Layout>;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized;
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 2;
using ConvProblemSize = typename conv::Conv2dProblemSize;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
"Vectors implied by the thread map must be divisible by the access type.");
//
// Simplifying assertions
//
static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1");
static_assert(OutputTileShape::kN == 1, "Require OutputTileShape::kN == 1");
static_assert(OutputTileShape::kC == Shape::kColumn, "Require OutputTile shape == channels per threadblock");
//
// Parameters structure
//
using Params = Depthwise2dFpropDirectConvParams<Layout>;
private:
Conv2dProblemSize const &problem_size_;
Params const &params_;
char const *pointer_;
// Base channels for the current threadblock
int base_c_;
// Base activation index for the current threadblock
int offset_initial_npq_;
// Base activation coordinate for the current threadblock
TensorCoord activation_base_;
// Initial thread position
int offset_initial_hwc_;
// Overall load instructions per thread
int iterator_load_;
// Thread loading position
int iterator_hwc_;
// Number of loads for activations tensor X
const int number_of_loads_;
public:
CUTLASS_HOST_DEVICE
DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized(
Params const &params,
Conv2dProblemSize const &problem_size,
Element const *ptr,
int thread_idx,
MatrixCoord const &threadblock_offset =
MatrixCoord()
)
: params_(params),
problem_size_(problem_size),
pointer_(reinterpret_cast<char const *>(ptr)),
offset_initial_npq_(threadblock_offset.row()),
offset_initial_hwc_(thread_idx),
iterator_load_(0),
number_of_loads_(params.activation_load_count) {
base_c_ = threadblock_offset.column();
set_activation_coord(offset_initial_npq_);
set_iteration_index(0);
}
CUTLASS_HOST_DEVICE
void set_activation_coord(int offset_npq) {
int offset_initial_n, offset_initial_p, offset_initial_q;
int residual;
params_.pq_divmod(offset_initial_n, residual, offset_npq);
params_.q_divmod(offset_initial_p, offset_initial_q, residual);
int base_n = offset_initial_n;
int base_h =
offset_initial_p * OutputTileShape::kH * problem_size_.stride_h - problem_size_.pad_h;
int base_w =
offset_initial_q * OutputTileShape::kW * problem_size_.stride_w - problem_size_.pad_w;
activation_base_ = TensorCoord(base_n, base_h, base_w, base_c_);
}
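// Worked example (illustrative, assuming a hypothetical 2x4 grid of output
// tiles per image): offset_npq == 13 decomposes as n = 13 / 8 = 1 with
// residual 5, then p = 5 / 4 = 1 and q = 1. With OutputTileShape::kH == 2,
// stride_h == 1 and pad_h == 1, this yields base_h = 1 * 2 * 1 - 1 = 1
// (base_w is computed the same way).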
CUTLASS_HOST_DEVICE
static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) {
return Params(
problem_size,
layout,
{Shape::kRow, Shape::kColumn},
{OutputTileShape::kN, OutputTileShape::kH, OutputTileShape::kW, OutputTileShape::kC},
sizeof_bits<Element>::value,
ThreadMap::kThreads,
ThreadMap::Detail::ShapeVec::kContiguous,
ThreadMap::kElementsPerAccess);
}
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
iterator_hwc_ = offset_initial_hwc_ + index * ThreadMap::kThreads;
iterator_load_ = index;
}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
}
CUTLASS_HOST_DEVICE
void advance() {
// Advance to the next threadblock
offset_initial_npq_ += problem_size_.split_k_slices;
set_activation_coord(offset_initial_npq_);
}
/// Returns the coordinate in the activations tensor X that is currently pointed to
/// by the iterator.
CUTLASS_HOST_DEVICE
TensorCoord at() const {
int c = iterator_hwc_ % ThreadMap::Detail::ShapeVec::kContiguous;
int next = iterator_hwc_ / ThreadMap::Detail::ShapeVec::kContiguous;
int h, w;
params_.activation_tile_w_divmod(h, w, next);
c = c * AccessType::kElements;
return activation_base_ + TensorCoord(0, h, w, c);
}
/// Returns true if the current coordinate is within the activations tensor X
CUTLASS_HOST_DEVICE
bool valid() const {
TensorCoord coord = at();
return coord.n() < problem_size_.N && coord.h() >= 0 && coord.h() < problem_size_.H &&
coord.w() >= 0 && coord.w() < problem_size_.W && coord.c() < problem_size_.C;
}
/// Returns a pointer to the vector starting at the current coordinate
CUTLASS_HOST_DEVICE
AccessType const *get() const {
TensorCoord coord = at();
LongIndex offset = params_.layout(coord);
AccessType const *ptr =
reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);
return ptr;
}
/// Increments to the next memory access
CUTLASS_HOST_DEVICE
DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized &operator++() {
++iterator_load_;
iterator_hwc_ += ThreadMap::kThreads;
if (iterator_load_ < number_of_loads_) {
return *this;
}
iterator_load_ = 0;
iterator_hwc_ = offset_initial_hwc_;
return *this;
}
/// Determines the activation size loaded by the iterator
CUTLASS_HOST_DEVICE
int get_load_size() {
return params_.activation_size;
}
/// Determines the number of iterations needed
CUTLASS_HOST_DEVICE
int get_iteration_num() {
return number_of_loads_;
}
/// Determines whether the Depthwise fprop can execute the given problem.
CUTLASS_HOST_DEVICE
static Status can_implement(Conv2dProblemSize const &problem_size) {
// check alignment constraint on iterator's contiguous dimension
if (problem_size.C % AccessType::kElements) {
return Status::kErrorInvalidProblem;
}
return Status::kSuccess;
}
};
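// Illustrative usage sketch (not part of the library): how a caller typically
// drives this iterator to fill a register fragment using only the interface
// shown above (valid(), get(), operator++). `Iterator` and `Fragment` are
// assumed template parameters, not names defined in this file.
template <typename Iterator, typename Fragment>
CUTLASS_DEVICE
void depthwise_activation_load_sketch(Iterator &iterator, Fragment &frag) {
  typename Iterator::AccessType *dst =
      reinterpret_cast<typename Iterator::AccessType *>(&frag);
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < Fragment::kElements / Iterator::AccessType::kElements; ++i) {
    // Predicated access: out-of-bounds loads are skipped; zero-filling the
    // fragment beforehand is the caller's responsibility in this sketch.
    if (iterator.valid()) {
      dst[i] = *iterator.get();
    }
    ++iterator;
  }
}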
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace conv
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,551 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/cache_operation.h"
#include "cutlass/conv/threadblock/depthwise_mma_base.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace conv {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Cache operation for operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Cache operation for operand B
cutlass::arch::CacheOperation::Kind CacheOpB,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Epilogue stores the data into global memory
typename Epilogue_,
/// iterator implementation variants
conv::IteratorAlgorithm IteratorAlgorithm_ = conv::IteratorAlgorithm::kOptimized,
/// Used for partial specialization
typename Enable = bool>
class DepthwiseFpropDirectConvMultipleStage :
public DepthwiseDirectConvMmaBase<Shape_, Policy_, Stages> {
public:
///< Base class
using Base = DepthwiseDirectConvMmaBase<Shape_, Policy_, Stages>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Iterates over tiles of A operand in global memory
using IteratorA = IteratorA_;
///< Iterates over tiles of B operand in global memory
using IteratorB = IteratorB_;
///< Policy describing tuning details
using Policy = Policy_;
using Epilogue = Epilogue_;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
static conv::IteratorAlgorithm const kIteratorAlgorithm = IteratorAlgorithm_;
//
// Dependent types
//
/// Fragment of accumulator tile
using ElementC = typename Policy::Operator::ElementC;
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Internal structure exposed for introspection.
struct Detail {
/// Number of cp.async instructions to load one stage of operand A
static int const AsyncCopyIterationsPerStageA =
IteratorA::ThreadMap::Iterations::kCount;
/// Number of cp.async instructions to load one stage of operand B
static int const AsyncCopyIterationsPerStageB =
IteratorB::ThreadMap::Iterations::kCount;
/// Number of stages
static int const kStages = Stages;
/// Number of cp.async instructions to load one group of operand B
static int const kAccessesPerGroupB =
(AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
};
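// Worked example (illustrative): with AsyncCopyIterationsPerStageB == 6 and
// Base::kWarpGemmIterations == 4, kAccessesPerGroupB == (6 + 3) / 4 == 2,
// i.e. two cp.async instructions are issued per warp-level MMA iteration.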
private:
using WarpLoadedFragmentA = typename Operator::FragmentA;
using WarpLoadedFragmentB = typename Operator::FragmentB;
using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
private:
//
// Data members
//
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DepthwiseFpropDirectConvMultipleStage(
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage &shared_storage,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx
):
Base(shared_storage, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
{
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset(
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
}
CUTLASS_DEVICE
void copy_tiles_and_advance(IteratorA &iterator_A,
IteratorB &iterator_B,
int group_start_A = 0,
int group_start_B = 0) {
if (kIteratorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation) {
// Number of iterations is a compile-time constant.
iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
this->smem_iterator_A_.set_iteration_index(group_start_A);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
typename IteratorA::AccessType *dst_ptr =
reinterpret_cast<typename IteratorA::AccessType *>(this->smem_iterator_A_.get());
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
++iterator_A;
}
++this->smem_iterator_A_;
}
} else {
// Number of iterations is a runtime value.
iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
this->smem_iterator_A_.set_iteration_index(group_start_A);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < iterator_A.get_iteration_num(); ++j) {
typename IteratorA::AccessType *dst_ptr =
reinterpret_cast<typename IteratorA::AccessType *>(this->smem_iterator_A_.get());
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
++iterator_A;
}
++this->smem_iterator_A_;
}
}
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
///< problem size of GEMM
int gemm_k_iterations,
///< destination accumulator tile
FragmentC &accum,
///< iterator over A operand in global memory
IteratorA &iterator_A,
///< Params of global memory iterator
typename IteratorA::Params const &iterator_a_params,
///< iterator over B operand in global memory
IteratorB &iterator_B,
///< Params of global memory iterator
typename IteratorB::Params const &iterator_b_params,
///< initial value of accumulator
FragmentC const &src_accum,
/// Epilogue
Epilogue &epilogue,
///< Output operator
typename Epilogue::OutputOp const &output_op,
///< Tile iterator for destination
typename Epilogue::OutputTileIterator &destination_iterator,
///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
typename Epilogue::OutputTileIterator &source_iterator,
int split_k_slices = 1
) {
//
// Prologue
//
// Issue several complete stages
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) {
if (stage == 0) {
iterator_B.set_iteration_index(0);
this->smem_iterator_B_.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
typename IteratorB::AccessType *dst_ptr =
reinterpret_cast<typename IteratorB::AccessType *>(this->smem_iterator_B_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
++iterator_B;
}
++this->smem_iterator_B_;
}
}
if (kIteratorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation) {
// Number of iterations is known at compile time.
iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
typename IteratorA::AccessType *dst_ptr =
reinterpret_cast<typename IteratorA::AccessType *>(this->smem_iterator_A_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
++iterator_A;
}
++this->smem_iterator_A_;
}
} else {
// Number of iterations is a runtime value.
iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_num(iterator_A.get_iteration_num());
this->smem_iterator_A_.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < iterator_A.get_iteration_num(); ++j) {
typename IteratorA::AccessType *dst_ptr =
reinterpret_cast<typename IteratorA::AccessType *>(this->smem_iterator_A_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
++iterator_A;
}
++this->smem_iterator_A_;
}
}
// Move to the next stage
iterator_A.advance();
this->smem_iterator_A_.add_tile_offset({1, 0});
// Inserts a fence to group cp.async instructions into stages.
cutlass::arch::cp_async_fence();
}
/////////////////////////////////////////////////////////////////////////////
// Waits until kStages-2 stages have committed.
cutlass::arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpLoadedFragmentA warp_loaded_frag_A[2];
WarpLoadedFragmentB warp_loaded_frag_B[2];
WarpTransformedFragmentA warp_transformed_frag_A[2];
WarpTransformedFragmentB warp_transformed_frag_B[2];
Operator warp_mma;
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.setup_initial_status(iterator_a_params);
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0],
warp_loaded_frag_A[0], warp_loaded_frag_B[0]);
//
// Mainloop
//
unsigned int iterations = 0;
constexpr int inner_loop_iterations = round_up(Base::kWarpGemmIterations, 2);
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > (-Base::kStages + 1);) { // Each iteration covers one CTA tile.
accum.clear();
//
// Loop over GEMM K dimension
//
// Computes a warp-level GEMM on data held in shared memory
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < inner_loop_iterations; ++warp_mma_k) {
if (Base::kWarpGemmIterations % 2 == 0 || warp_mma_k + 1 != Base::kWarpGemmIterations) {
// Load warp-level tiles from shared memory, wrapping back to the starting
// k offset when the last group has been reached.
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Shape::kK);
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Shape::kK);
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
}
if (warp_mma_k > 0)
warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2],
warp_transformed_frag_B[warp_mma_k % 2],
warp_loaded_frag_A[warp_mma_k % 2],
warp_loaded_frag_B[warp_mma_k % 2]);
// Issue global->shared copies for the next stage
int group_start_iteration_A, group_start_iteration_B;
if (warp_mma_k == 0) {
group_start_iteration_A = 0;
group_start_iteration_B = 0;
copy_tiles_and_advance(
iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
}
if (warp_mma_k < Base::kWarpGemmIterations) {
warp_mma(
accum,
warp_transformed_frag_A[warp_mma_k % 2],
warp_transformed_frag_B[warp_mma_k % 2],
accum
);
}
if (warp_mma_k + 1 == inner_loop_iterations)
warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
warp_transformed_frag_B[(warp_mma_k + 1) % 2],
warp_loaded_frag_A[(warp_mma_k + 1) % 2],
warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
if (warp_mma_k + 2 == inner_loop_iterations) {
// Inserts a fence to group cp.async instructions into stages.
cutlass::arch::cp_async_fence();
// Waits until kStages-2 stages of cp.async have committed
arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
// Move to the next CTA tile
iterator_A.advance();
this->smem_iterator_A_.add_tile_offset({1, 0});
// Add negative offsets to return iterators to the 'start' of the
// circular buffer in shared memory
if (smem_write_stage_idx == (Base::kStages - 1)) {
this->smem_iterator_A_.add_tile_offset({-Base::kStages, 0});
smem_write_stage_idx = 0;
} else {
++smem_write_stage_idx;
}
if (smem_read_stage_idx == (Base::kStages - 1)) {
this->warp_tile_iterator_A_.advance(- (Base::kStages-1) * iterator_A.get_load_size());
smem_read_stage_idx = 0;
} else {
this->warp_tile_iterator_A_.advance(iterator_A.get_load_size());
++smem_read_stage_idx;
}
if (kIteratorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation) {
this->warp_tile_iterator_A_.setup_initial_status(iterator_a_params);
}
// Go back to the start position; B has no multiple stages
this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Shape::kK, 0});
--gemm_k_iterations;
}
}
//
// Epilogue
//
int32_t smem_base_offset = iterator_B.get_load_size() + (iterations % Base::kStages) * iterator_A.get_load_size();
destination_iterator.set_tile_index(iterations * split_k_slices);
source_iterator.set_tile_index(iterations * split_k_slices);
epilogue(output_op, destination_iterator, accum, source_iterator, smem_base_offset);
++iterations;
}
// Insert fence and wait for all outstanding cp.async operations to commit.
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
};
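// Minimal sketch (illustrative, not part of the library) of the cp.async
// staging discipline the mainloop above relies on: each stage's copies are
// closed with a fence, and math may proceed once at most kStages - 2 stages
// are still in flight. `kStages` is an assumed template parameter.
template <int kStages>
CUTLASS_DEVICE
void cp_async_staging_sketch() {
  CUTLASS_PRAGMA_UNROLL
  for (int stage = 0; stage < kStages - 1; ++stage) {
    // ... issue the cp.async copies belonging to this stage here ...
    cutlass::arch::cp_async_fence();  // group this stage's copies
  }
  // Block until all but kStages - 2 committed stages have completed.
  cutlass::arch::cp_async_wait<kStages - 2>();
  __syncthreads();
}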
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace conv
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,261 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile)
matrix from memory.
This iterator assumes TensorNHWC layout of tensors in Global Memory.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv2d_problem_size.h"
#include "cutlass/conv/threadblock/conv2d_params.h"
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace conv {
namespace threadblock {
template <typename Shape_,
typename Element_,
typename Layout_,
typename ThreadMap_,
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess> >
class DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized {
public:
//
// Types
//
using Shape = Shape_;
using Element = Element_;
using Layout = Layout_;
using ThreadMap = ThreadMap_;
using AccessType = AccessType_;
using TensorRef = cutlass::TensorRef<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized;
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
static int const kConvDim = 2;
using ConvProblemSize = typename conv::Conv2dProblemSize;
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
static int const kFilterSize = ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess * ThreadMap::kThreads *
sizeof_bits<Element>::value / 8;
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
"Vectors implied by the thread map must be divisible by the access type.");
//
// Simplifying assertions
//
static_assert(ThreadMap::Iterations::kContiguous == 1,
"Require Iterations::kContiguous == 1");
//
// Parameters structure
//
using Params = Depthwise2dFpropDirectConvFilterIteratorParams<Layout>;
protected:
Conv2dProblemSize const &problem_size_;
Params const &params_;
LongIndex iteration_contiguous_;
LongIndex iteration_strided_;
LongIndex iteration_vector_;
char const *pointer_;
int filter_k_;
int offset_trs_[ThreadMap::Iterations::kStrided];
public:
CUTLASS_HOST_DEVICE
DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized(
Params const &params,
Conv2dProblemSize const &problem_size,
Element const *ptr,
int thread_idx,
MatrixCoord const &threadblock_offset = MatrixCoord()
):
params_(params),
problem_size_(problem_size),
pointer_(reinterpret_cast<char const *>(ptr)),
filter_k_(0) {
layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);
filter_k_ = threadblock_offset.column() + thread_coord.contiguous();
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
offset_trs_[s] = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;
}
set_iteration_index(0);
}
CUTLASS_HOST_DEVICE
static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) {
return Params(problem_size, layout, {Shape::kRow, Shape::kColumn}, kFilterSize);
}
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(Index index) {
iteration_vector_ = index % kAccessesPerVector;
int residual_access = index / kAccessesPerVector;
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
}
CUTLASS_HOST_DEVICE
void advance() {
// Do nothing because the filter is persistent in the SMEM
}
/// Returns the coordinate in the filter tensor W that is currently pointed to
/// by the iterator.
CUTLASS_HOST_DEVICE
TensorCoord at() const {
int k = filter_k_ + iteration_vector_ * AccessType::kElements;
int trs = offset_trs_[iteration_strided_];
return TensorCoord(k, trs, 0, 0); // As a 2D matrix
}
/// Returns true if the current coordinate is within the filter tensor W
CUTLASS_HOST_DEVICE
bool valid() const {
TensorCoord coord = at();
return coord.n() < problem_size_.K &&
coord.h() < Shape::kColumn;
}
/// Returns a pointer to the vector starting at the current coordinate
CUTLASS_HOST_DEVICE
AccessType const *get() const {
TensorCoord coord = at();
int64_t offset = coord.n();
if (params_.is_convolution) {
offset += (Shape::kColumn - coord.h() - 1) * problem_size_.K;
} else {
offset += coord.h() * problem_size_.K;
}
return reinterpret_cast<AccessType const *>(pointer_ +
offset * sizeof_bits<Element>::value / 8);
}
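// Worked example (illustrative): for a 3x3 filter flattened to
// Shape::kColumn == 9, coord.h() == 0 maps to row 8 when
// params_.is_convolution is set (filter traversed in reverse) and to row 0
// otherwise (cross-correlation).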
/// Increments to the next memory access
CUTLASS_HOST_DEVICE
DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized &operator++() {
++iteration_vector_;
if (iteration_vector_ < kAccessesPerVector) {
return *this;
}
iteration_vector_ = 0;
++iteration_contiguous_;
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
return *this;
}
iteration_contiguous_ = 0;
++iteration_strided_;
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
return *this;
}
iteration_strided_ = 0;
return *this;
}
/// Determines the filter size loaded by the iterator
CUTLASS_HOST_DEVICE
int get_load_size() {
return kFilterSize;
}
/// Determines whether the Implicit GEMM can execute the given problem.
CUTLASS_HOST_DEVICE
static Status can_implement(Conv2dProblemSize const &problem_size) {
// check alignment constraint on iterator's contiguous dimension
if (problem_size.K % AccessType::kElements) {
return Status::kErrorInvalidProblem;
}
// check whether the runtime filter size matches the templated filter size.
if ((problem_size.R * problem_size.S) != Shape::kColumn) {
return Status::kErrorInvalidProblem;
}
return Status::kSuccess;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace conv
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,229 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a direct convolution threadblock-scoped depthwise kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace conv {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Policy object describing MmaTensorOp
template <
/// Warp-level GEMM operator (concept: gemm::warp::Mma)
typename Operator_,
/// Padding used for A operand in shared memory (concept: MatrixShape)
typename SmemPaddingA_,
/// Padding used for B operand in shared memory (concept: MatrixShape)
typename SmemPaddingB_,
///
typename ThreadMapA_,
///
typename ThreadMapB_,
/// Number of partitions of K dimension of GEMM
int PartitionsK = 1>
struct DepthwiseDirectConvMmaPolicy {
/// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt)
using Operator = Operator_;
/// Padding used for A operand in shared memory
using SmemPaddingA = SmemPaddingA_;
/// Padding used for B operand in shared memory
using SmemPaddingB = SmemPaddingB_;
using ThreadMapA = ThreadMapA_;
using ThreadMapB = ThreadMapB_;
/// Number of partitions of K dimension
static int const kPartitionsK = PartitionsK;
};
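// Illustrative instantiation (all names hypothetical, for exposition only):
//   using Policy = DepthwiseDirectConvMmaPolicy<
//       WarpMma,                      // warp-level operator
//       cutlass::MatrixShape<0, 0>,   // no SMEM padding for A
//       cutlass::MatrixShape<0, 0>,   // no SMEM padding for B
//       ThreadMapA, ThreadMapB, 1>;   // thread maps and kPartitionsK == 1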
////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Used for partial specialization
typename Enable = bool>
class DepthwiseDirectConvMmaBase {
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Policy describing tuning details
using Policy = Policy_;
//
// Dependent types
//
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Shape describing the overall GEMM computed from shared memory
/// by each warp.
using WarpGemm = typename Policy::Operator::Shape;
/// Shape describing the number of warps filling the CTA
using WarpCount = cutlass::gemm::
GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations.
/// kWarpGemmIterations may be even or odd.
static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
/// Number of stages
static int const kStages = Stages;
/// Tensor reference to the A operand
using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
/// Tensor reference to the B operand
using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
static_assert(kWarpGemmIterations > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
//
// Nested structs
//
/// Shared storage object needed by threadblock-scoped GEMM
class SharedStorage {
public:
//
// Type definitions
//
/// Shape of the A matrix operand in shared memory
using ShapeA = MatrixShape<1, // Row extent is a runtime value; not known at compile time
Shape::kN + Policy::SmemPaddingA::kRow>;
/// Shape of the B matrix operand in shared memory
using ShapeB = MatrixShape<Policy::ThreadMapB::StorageShape::kStrided +
Policy::SmemPaddingB::kRow, // filter_rs_size
Policy::ThreadMapB::StorageShape::kContiguous +
Policy::SmemPaddingB::kColumn>; // threadblock tile N
public:
//
// Data members
//
// Place the persistent B matrix ahead of the dynamically sized A matrix
/// Buffer for B operand
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
/// Buffer for A operand
/// Extent is not known at compile time; a single element is allocated only to obtain the SMEM start address.
AlignedBuffer<typename Operator::ElementA, 1> operand_A;
public:
//
// Methods
//
/// Returns a layout object for the A matrix
CUTLASS_DEVICE
static typename Operator::LayoutA LayoutA() {
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
}
/// Returns a layout object for the B matrix
CUTLASS_HOST_DEVICE
static typename Operator::LayoutB LayoutB() {
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
}
/// Returns a TensorRef to the A operand
CUTLASS_HOST_DEVICE
TensorRefA operand_A_ref() { return TensorRefA{operand_A.data(), LayoutA()}; }
/// Returns a TensorRef to the B operand
CUTLASS_HOST_DEVICE
TensorRefB operand_B_ref() { return TensorRefB{operand_B.data(), LayoutB()}; }
};
protected:
//
// Data members
//
/// Iterator to load a warp-scoped tile of A operand from shared memory
typename Operator::IteratorA warp_tile_iterator_A_;
/// Iterator to load a warp-scoped tile of B operand from shared memory
typename Operator::IteratorB warp_tile_iterator_B_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DepthwiseDirectConvMmaBase(
///< Shared storage needed for internal use by threadblock-scoped GEMM
SharedStorage &shared_storage,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace conv
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -44,11 +44,17 @@
#include "cutlass/matrix_shape.h"
#include "cutlass/gemm/warp/mma.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/warp/mma_depthwise_simt.h"
#include "cutlass/gemm/threadblock/mma_pipelined.h"
#include "cutlass/gemm/threadblock/mma_singlestage.h"
#include "cutlass/gemm/threadblock/mma_base.h"
#include "cutlass/conv/warp/mma_depthwise_simt.h"
#include "cutlass/conv/threadblock/depthwise_mma_base.h"
#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h"
#include "cutlass/arch/cache_operation.h"
@ -58,6 +64,95 @@ namespace cutlass {
namespace conv {
namespace threadblock {
namespace detail {
//
// Converts WarpShapeM, the total number of elements covered by the warp tile,
// into the 2D tile of elements held by each partition within the warp.
// The goal is for each thread's tile of elements to be as square as
// possible for performance (4x4 will be faster than 2x8).
template<int WarpShapeM, // The number of elements (1D) contained in the entire warp
int WarpNumThreadsM> // The number of partitions within the warp
struct SimtWarpShape {
// kP * kQ * WarpNumThreadsM = WarpShapeM
// Add further specializations as needed.
};
template <>
struct SimtWarpShape<4, 4> {
static constexpr int kP = 1;
static constexpr int kQ = 1;
};
template <>
struct SimtWarpShape<4, 2> {
static constexpr int kP = 2;
static constexpr int kQ = 1;
};
template <>
struct SimtWarpShape<4, 1> {
static constexpr int kP = 2;
static constexpr int kQ = 2;
};
template <>
struct SimtWarpShape<8, 1> {
static constexpr int kP = 2;
static constexpr int kQ = 4;
};
template <>
struct SimtWarpShape<8, 2> {
static constexpr int kP = 2;
static constexpr int kQ = 2;
};
template <>
struct SimtWarpShape<8, 4> {
static constexpr int kP = 1;
static constexpr int kQ = 2;
};
template <>
struct SimtWarpShape<16, 1> {
static constexpr int kP = 4;
static constexpr int kQ = 4;
};
template <>
struct SimtWarpShape<16, 2> {
static constexpr int kP = 2;
static constexpr int kQ = 4;
};
template <>
struct SimtWarpShape<16, 4> {
static constexpr int kP = 2;
static constexpr int kQ = 2;
};
template <int WarpNumThreadsM>
struct SimtWarpShape<25, WarpNumThreadsM> {
static_assert(WarpNumThreadsM == 1, "WarpShapeM cannot be split evenly across threads");
static constexpr int kP = 5;
static constexpr int kQ = 5;
};
template <>
struct SimtWarpShape<32, 1> {
static constexpr int kP = 4;
static constexpr int kQ = 8;
};
template <>
struct SimtWarpShape<32, 2> {
static constexpr int kP = 4;
static constexpr int kQ = 4;
};
template <>
struct SimtWarpShape<32, 4> {
static constexpr int kP = 2;
static constexpr int kQ = 4;
};
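// Sanity check (illustrative): every specialization must satisfy
// kP * kQ * WarpNumThreadsM == WarpShapeM, e.g. SimtWarpShape<16, 2> holds a
// 2x4 tile per thread and 2 * 4 * 2 == 16.
static_assert(SimtWarpShape<16, 2>::kP * SimtWarpShape<16, 2>::kQ * 2 == 16,
              "SimtWarpShape must cover WarpShapeM exactly");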
} // namespace detail
template <
/// Shape of threadblock-scoped matrix multiply operator
typename Shape,
@ -114,6 +209,74 @@ struct DepthwiseMmaCoreWithLaneAccessSize;
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Shape of threadblock-scoped matrix multiply operator
typename Shape,
/// Shape of threadblock-scoped output tile
typename ThreadBlockOutputShape,
/// Shape of filter shape per threadblock
typename FilterShape,
/// Shape of warp-level matrix multiply operator
typename WarpShape,
/// Shape of one matrix production operation (concept: GemmShape)
typename InstructionShape,
/// Element data type of A operand
typename ElementA,
/// Layout of operand A
typename LayoutA,
/// Element data type of B operand
typename ElementB,
/// Layout of operand B
typename LayoutB,
/// Data type of accumulator
typename ElementC,
/// Layout of accumulator
typename LayoutC,
/// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp)
typename OperatorClass,
/// Size of a warp-scoped per thread access
int kLaneAccessSizeA_ = 0,
/// Size of a warp-scoped per thread access
int kLaneAccessSizeB_ = 0,
/// Number of stages
int Stages = 2,
/// Operation performed by MMA
typename Operator = typename platform::conditional<
(platform::is_same<OperatorClass,
cutlass::arch::OpClassTensorOp>::value) &&
(platform::is_same<ElementA, int8_t>::value ||
platform::is_same<ElementA, int4b_t>::value ||
platform::is_same<ElementA, uint8_t>::value ||
platform::is_same<ElementA, uint4b_t>::value),
cutlass::arch::OpMultiplyAddSaturate,
cutlass::arch::OpMultiplyAdd>::type,
/// Iterator algo type
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
/// Stride ( MatrixShape<Height, Width> )
typename StrideShape = cutlass::MatrixShape<-1, -1>,
/// Dilation ( MatrixShape<Height, Width> )
typename DilationShape = cutlass::MatrixShape<-1, -1>,
/// Activation Shape loaded by threadblock
typename ActivationShape = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor = false,
/// Cache operation of operand A
cutlass::arch::CacheOperation::Kind CacheOpA =
cutlass::arch::CacheOperation::Global,
/// Cache operation of operand B
cutlass::arch::CacheOperation::Kind CacheOpB =
cutlass::arch::CacheOperation::Global,
/// per-element transformation for elements of A
ComplexTransform TransformA = ComplexTransform::kNone,
/// per-element transformation for elements of B
ComplexTransform TransformB = ComplexTransform::kNone,
bool IsComplex = false // (is_complex<ElementA>::value || is_complex<ElementB>::value)
>
struct DepthwiseDirectConvMmaCoreWithLaneAccessSize;
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Shape of threadblock-scoped matrix multiply operator
typename Shape,
@ -332,6 +495,458 @@ struct DepthwiseMmaCoreWithLaneAccessSize<Shape_,
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization:
///
/// A: row-major
/// B: row-major
/// Operator: simt class
///
/// This uses the default warp-level operator given tile sizes
template <
/// Shape of threadblock-scoped matrix multiply operator (concept:
/// GemmShape)
typename Shape_,
/// Shape of threadblock-scoped output tile (concept: TensorNHWCShape)
typename ThreadBlockOutputShape_,
/// Shape of filter shape per threadblock
typename FilterShape_,
/// Shape of warp-level matrix multiply operator (concept: GemmShape)
typename WarpShape_,
/// Data type of A operand
typename ElementA_,
/// Data type of B operand
typename ElementB_,
/// Data type of accumulator
typename ElementC_,
/// Layout of accumulator
typename LayoutC_,
/// Size of a warp-scoped per thread access
int kLaneAccessSizeA_,
/// Number of stages
int Stages_,
/// Operation performed by GEMM
typename Operator_>
struct DepthwiseDirectConvMmaCoreWithLaneAccessSize<Shape_,
ThreadBlockOutputShape_,
FilterShape_,
WarpShape_,
cutlass::gemm::GemmShape<1, 1, 1>,
ElementA_,
layout::RowMajor,
ElementB_,
layout::ColumnMajor,
ElementC_,
LayoutC_,
arch::OpClassSimt,
kLaneAccessSizeA_,
128,
Stages_,
Operator_> {
using Shape = Shape_;
using FilterShape = FilterShape_;
using WarpShape = WarpShape_;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
using ElementA = ElementA_;
using LayoutA = layout::RowMajor;
using ElementB = ElementB_;
using LayoutB = layout::ColumnMajor;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
using OperatorClass = arch::OpClassSimt;
static int const kLaneAccessSizeB = 128;
// Divisibility requirements
static_assert(kLaneAccessSizeB > 0,
              "Size of a warp-scoped per-thread access must be larger than zero");
/// Default Operator
using Operator = Operator_;
/// Number of warps present
using WarpCount = cutlass::gemm::GemmShape<
Shape::kM / WarpShape::kM,
Shape::kN / WarpShape::kN,
1
>;
// Divisibility requirements
static_assert(
!(Shape::kM % WarpShape::kM) &&
!(Shape::kN % WarpShape::kN),
"Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
);
/// Number of threads per warp
static int const kWarpSize = cutlass::gemm::warp::WarpSize<arch::OpClassSimt>::value;
/// Number of threads total
static int const kThreads = WarpCount::kCount * kWarpSize;
// For Gmem load
static int const kElementsPerAccessA = 128 / sizeof_bits<ElementA>::value;
static int const kElementsPerAccessB = 128 / sizeof_bits<ElementB>::value;
//
// Shared memory layouts
//
using SmemLayoutA = layout::RowMajor;
using SmemLayoutB = layout::RowMajor;
//
// Iterators to write to shared memory
//
/// ThreadMap of iterator A
using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<Shape::kN, 1>, // kStrided is 1 because the activation shape is a runtime value.
kThreads,
kElementsPerAccessA
>;
/// ThreadMap of iterator A
using SmemThreadMapA = IteratorThreadMapA;
/// Shared memory iterator to A operand
using SmemIteratorA = transform::threadblock::RegularTileAccessIteratorDirectConv<
MatrixShape<1, Shape::kN>, // kRow is 1 because the row extent is a runtime value
ElementA,
SmemLayoutA,
0,
SmemThreadMapA, // was IteratorThreadMapA
true // Dynamic iterations.
>;
/// ThreadMap of iterator B
using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<Shape::kN, FilterShape::kCount>,
kThreads,
kElementsPerAccessB
>;
/// ThreadMap used to store operand B to shared memory
using SmemThreadMapB = IteratorThreadMapB;
/// Shared memory iterator to B operand
using SmemIteratorB = transform::threadblock::RegularTileAccessIteratorDirectConv<
MatrixShape<FilterShape::kCount, Shape::kN>,
ElementB,
SmemLayoutB,
0,
SmemThreadMapB, // was IteratorThreadMapB
false // static iterations.
>;
//
// Warp-level matrix multiply operator
//
// Groups per thread:
//   multi-byte types (fp32, fp16): 2 groups
//   single-byte types (int8):      4 groups
static const int GroupsPerThread = sizeof(ElementB) > 1 ? 2 : 4;
// Define the warp-level op
static const int WarpNumThreadsN = cutlass::const_min(WarpShape::kN / GroupsPerThread, kWarpSize);
static const int WarpNumThreadsM = kWarpSize / WarpNumThreadsN;
static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN),
"WarpShape must be divisible by ThreadTile shape.");
// Get output P, Q per thread
static const int TileP = cutlass::conv::threadblock::detail::SimtWarpShape<WarpShape::kM, WarpNumThreadsM>::kP;
static const int TileQ = cutlass::conv::threadblock::detail::SimtWarpShape<WarpShape::kM, WarpNumThreadsM>::kQ;
static const int LaneLayout = 1;
static const int numElementsB = kLaneAccessSizeB / sizeof_bits<ElementB>::value;
static const int LaneN = cutlass::const_min(numElementsB, WarpShape::kN / WarpNumThreadsN);
// Define the output tile computed by each thread
using ThreadOutputShape = cutlass::conv::TensorNHWCShape<1, TileP, TileQ, LaneN>;
// Fetch channels with the same access size
static const int LaneM = LaneN;
// No paddings
static int const kPaddingM = 0;
static int const kPaddingN = 0;
static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN),
"Padding must be divisible by Lane");
// These should also be bounded by the maximum thread-tile size.
using LaneMmaShape = cutlass::gemm::GemmShape<
LaneM,
LaneN,
1>;
using Policy = cutlass::gemm::warp::MmaSimtPolicy<
cutlass::MatrixShape<WarpNumThreadsM, WarpNumThreadsN>, // WarpShape
cutlass::layout::RowMajorInterleaved<LaneLayout>, // LaneLayout
LaneMmaShape
>;
using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseDirectConvSimt<
WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<>
FilterShape, /// Shape of filter shape per threadblock - concept: gemm::GemmShape<Depth, Height, Width>
ThreadOutputShape, /// Size of the output tile computed by thread - concept: conv::TensorNHWCShape<>
ThreadBlockOutputShape_, /// Size of the output tile computed by threadblock - concept: conv::TensorNHWCShape<>
ElementA, /// Data type of A elements
SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout)
ElementB, /// Data type of B elements
SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout)
ElementC, /// Element type of C matrix
LayoutC, /// Layout of C matrix (concept: MatrixLayout)
Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy)
>;
/// Policy used to define MmaPipelined
using MmaPolicy = cutlass::conv::threadblock::DepthwiseDirectConvMmaPolicy<
MmaWarpSimt,
MatrixShape<kPaddingM, 0>, // skew for A matrix to avoid SMEM bank conflicts
MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts
IteratorThreadMapA,
IteratorThreadMapB,
WarpCount::kK
>;
};
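// Worked example (illustrative): for a 16x64 WarpShape with half-precision B,
// GroupsPerThread == 2, WarpNumThreadsN == min(64 / 2, 32) == 32,
// WarpNumThreadsM == 1, SimtWarpShape<16, 1> gives TileP x TileQ == 4x4,
// numElementsB == 128 / 16 == 8 and LaneN == min(8, 64 / 32) == 2, so each
// thread computes a 1x4x4x2 ThreadOutputShape tile.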
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization:
///
/// A: row-major
/// B: row-major
/// Operator: simt class
///
/// This uses the default warp-level operator given tile sizes
template <
/// Shape of threadblock-scoped matrix multiply operator (concept:
/// GemmShape)
typename Shape_,
/// Shape of threadblock-scoped output tile (concept: TensorNHWCShape)
typename ThreadBlockOutputShape_,
/// Shape of filter shape per threadblock
typename FilterShape_,
/// Shape of warp-level matrix multiply operator (concept: GemmShape)
typename WarpShape_,
/// Data type of A operand
typename ElementA_,
/// Data type of B operand
typename ElementB_,
/// Data type of accumulator
typename ElementC_,
/// Layout of accumulator
typename LayoutC_,
/// Size of a warp-scoped per thread access
int kLaneAccessSizeA_,
/// Number of stages
int Stages_,
/// Operation performed by GEMM
typename Operator_,
/// Stride ( MatrixShape<Height, Width> )
typename StrideShape_,
/// Dilation ( MatrixShape<Height, Width> )
typename DilationShape_,
/// Activation Shape loaded by threadblock
typename ActivationShape_>
struct DepthwiseDirectConvMmaCoreWithLaneAccessSize<Shape_,
ThreadBlockOutputShape_,
FilterShape_,
WarpShape_,
cutlass::gemm::GemmShape<1, 1, 1>,
ElementA_,
layout::RowMajor,
ElementB_,
layout::ColumnMajor,
ElementC_,
LayoutC_,
arch::OpClassSimt,
kLaneAccessSizeA_,
128,
Stages_,
Operator_,
IteratorAlgorithm::kFixedStrideDilation,
StrideShape_,
DilationShape_,
ActivationShape_> {
using Shape = Shape_;
using FilterShape = FilterShape_;
using WarpShape = WarpShape_;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
using ElementA = ElementA_;
using LayoutA = layout::RowMajor;
using ElementB = ElementB_;
using LayoutB = layout::ColumnMajor;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
using OperatorClass = arch::OpClassSimt;
using StrideShape = StrideShape_;
using DilationShape = DilationShape_;
using ThreadBlockOutputShape = ThreadBlockOutputShape_;
using ActivationShape = ActivationShape_;
static int const kLaneAccessSizeB = 128;
// Divisibility requirements
static_assert(kLaneAccessSizeB > 0,
              "Size of a warp-scoped per-thread access must be larger than zero");
/// Default Operator
using Operator = Operator_;
/// Number of warps present
using WarpCount = cutlass::gemm::GemmShape<
Shape::kM / WarpShape::kM,
Shape::kN / WarpShape::kN,
1
>;
// Divisibility requirements
static_assert(
!(Shape::kM % WarpShape::kM) &&
!(Shape::kN % WarpShape::kN),
"Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
);
/// Number of threads per warp
static int const kWarpSize = cutlass::gemm::warp::WarpSize<arch::OpClassSimt>::value;
/// Number of threads total
static int const kThreads = WarpCount::kCount * kWarpSize;
// For Gmem load
static int const kElementsPerAccessA = 128 / sizeof_bits<ElementA>::value;
static int const kElementsPerAccessB = 128 / sizeof_bits<ElementB>::value;
//
// Shared memory layouts
//
using SmemLayoutA = layout::RowMajor;
using SmemLayoutB = layout::RowMajor;
//
// Iterators to write to shared memory
//
/// ThreadMap of iterator A
using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<ActivationShape::kC, ActivationShape::kNHW>,
kThreads,
kElementsPerAccessA
>;
/// ThreadMap of iterator A
using SmemThreadMapA = IteratorThreadMapA;
/// Shared memory iterator to A operand
using SmemIteratorA = transform::threadblock::RegularTileAccessIteratorDirectConv<
MatrixShape<ActivationShape::kNHW, ActivationShape::kC>,
ElementA,
SmemLayoutA,
0,
SmemThreadMapA, // was IteratorThreadMapA
false // static iterations.
>;
/// ThreadMap of iterator B
using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<Shape::kN, FilterShape::kCount>,
kThreads,
kElementsPerAccessB
>;
/// ThreadMap used to store operand B to shared memory
using SmemThreadMapB = IteratorThreadMapB;
/// Shared memory iterator to B operand
using SmemIteratorB = transform::threadblock::RegularTileAccessIteratorDirectConv<
MatrixShape<FilterShape::kCount, Shape::kN>,
ElementB,
SmemLayoutB,
0,
SmemThreadMapB, // was IteratorThreadMapB
false // static iterations.
>;
//
// Warp-level matrix multiply operator
//
// Groups per thread:
//   multi-byte types (fp32, fp16): 2 groups
//   single-byte types (int8):      4 groups
static const int GroupsPerThread = sizeof(ElementB) > 1 ? 2 : 4;
// Define the warp-level op
static const int WarpNumThreadsN = cutlass::const_min(WarpShape::kN / GroupsPerThread, kWarpSize);
static const int WarpNumThreadsM = kWarpSize / WarpNumThreadsN;
static const int TileP = cutlass::conv::threadblock::detail::SimtWarpShape<WarpShape::kM, WarpNumThreadsM>::kP;
static const int TileQ = cutlass::conv::threadblock::detail::SimtWarpShape<WarpShape::kM, WarpNumThreadsM>::kQ;
static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN),
"WarpShape must be divisible by ThreadTile shape.");
static const int LaneLayout = 1;
static const int numElementsB = kLaneAccessSizeB / sizeof_bits<ElementB>::value;
static const int LaneN = cutlass::const_min(numElementsB, WarpShape::kN / WarpNumThreadsN);
// Define the output tile computed by each thread
using ThreadOutputShape = cutlass::conv::TensorNHWCShape<1, TileP, TileQ, LaneN>;
// Fetch channels with the same access size
static const int LaneM = LaneN;
// No paddings
static int const kPaddingM = 0;
static int const kPaddingN = 0;
static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN),
"Padding must be divisible by Lane");
// These should also be capped by the thread-tile extents
using LaneMmaShape = cutlass::gemm::GemmShape<
LaneM,
LaneN,
1>;
using Policy = cutlass::gemm::warp::MmaSimtPolicy<
cutlass::MatrixShape<WarpNumThreadsM, WarpNumThreadsN>, // WarpShape
cutlass::layout::RowMajorInterleaved<LaneLayout>, // LaneLayout
LaneMmaShape
>;
using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseDirectConvSimt<
WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<>
FilterShape, /// Shape of filter shape per threadblock - concept: gemm::GemmShape<Depth, Height, Width>
ThreadOutputShape, /// Size of the output tile computed by thread - concept: conv::TensorNHWCShape<>
ThreadBlockOutputShape, /// Size of the output tile computed by threadblock - concept: conv::TensorNHWCShape<>
ElementA, /// Data type of A elements
SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout)
ElementB, /// Data type of B elements
SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout)
ElementC, /// Element type of C matrix
LayoutC, /// Layout of C matrix (concept: MatrixLayout)
Policy, /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy)
IteratorAlgorithm::kFixedStrideDilation, /// Iterator algo type
StrideShape, /// Stride ( MatrixShape<Height, Width> )
DilationShape, /// Dilation ( MatrixShape<Height, Width> )
ActivationShape /// Activation Shape loaded by threadblock
>;
/// Policy used to define MmaPipelined
using MmaPolicy = cutlass::conv::threadblock::DepthwiseDirectConvMmaPolicy<
MmaWarpSimt,
MatrixShape<kPaddingM, 0>, // skew for A matrix to avoid SMEM bank conflicts
MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts
IteratorThreadMapA,
IteratorThreadMapB,
WarpCount::kK
>;
};
} // namespace threadblock
} // namespace conv
} // namespace cutlass
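The warp-decomposition arithmetic above is easier to audit with concrete values. A minimal host-side sketch, under assumed parameters (fp16 filters, a 16x64 warp tile, 32-thread warps) that are illustrative and not taken from the diff:

#include <algorithm>

// Assumed parameters (illustrative only): ElementB = half (16 bits),
// WarpShape = GemmShape<16, 64, 1>, kWarpSize = 32, kLaneAccessSizeB = 128.
constexpr int kWarpSize        = 32;
constexpr int kLaneAccessSizeB = 128;
constexpr int kWarpShapeN      = 64;
constexpr int kElementBBits    = 16;

constexpr int GroupsPerThread = (kElementBBits / 8 > 1) ? 2 : 4;                       // fp16 -> 2
constexpr int WarpNumThreadsN = std::min(kWarpShapeN / GroupsPerThread, kWarpSize);    // min(32, 32) = 32
constexpr int WarpNumThreadsM = kWarpSize / WarpNumThreadsN;                           // 32 / 32 = 1
constexpr int numElementsB    = kLaneAccessSizeB / kElementBBits;                      // 128 / 16 = 8
constexpr int LaneN           = std::min(numElementsB, kWarpShapeN / WarpNumThreadsN); // min(8, 2) = 2

static_assert(WarpNumThreadsM * WarpNumThreadsN == kWarpSize, "Threads must tile the warp exactly");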

View File

@ -165,7 +165,29 @@ struct StridedDgradIdentityThreadblockSwizzle :
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Threadblock swizzling function for depthwise direct 2d convolutions
template <int N = 1, int Output_N = 1, int Output_P = 1, int Output_Q = 1>
struct DepthwiseDirect2dConvIdentityThreadblockSwizzle
: public gemm::threadblock::GemmIdentityThreadblockSwizzle<N> {
CUTLASS_HOST_DEVICE
DepthwiseDirect2dConvIdentityThreadblockSwizzle() {}
/// Returns the shape of the problem in units of logical tiles
CUTLASS_HOST_DEVICE
gemm::GemmCoord get_tiled_shape(cutlass::conv::Operator conv_operator,
cutlass::conv::Conv2dProblemSize const &problem_size,
gemm::GemmCoord tile_size,
int split_k_slices) const {
gemm::GemmCoord implicit_gemm_problem_size =
cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size);
return gemm::GemmCoord(1,
(implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(),
split_k_slices);
}
};
} // namespace threadblock
} // namespace gemm
} // namespace conv
} // namespace cutlass
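Since the M extent of the grid is pinned to 1, the tiling above reduces to a ceiling division over the implicit GEMM N extent. A worked example with made-up sizes:

// Hypothetical sizes, for illustration only:
int gemm_n         = 192;   // implicit GEMM N extent (output channels)
int tile_n         = 64;    // threadblock tile N
int split_k_slices = 1;
int grid_n = (gemm_n + tile_n - 1) / tile_n;   // ceiling division -> 3
// get_tiled_shape() would return gemm::GemmCoord(1, 3, 1) for these inputs.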

View File

@ -42,6 +42,9 @@
#include "cutlass/gemm/warp/mma.h"
#include "cutlass/gemm/thread/mma.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/thread/depthwise_mma.h"
#include "cutlass/gemm/warp/mma_simt_tile_iterator.h"
#include "cutlass/gemm/warp/mma_simt_policy.h"
@ -91,7 +94,7 @@ class MmaDepthwiseSimt
public:
/// Shape of warp-level matrix operation (concept: GemmShape)
using Shape = Shape_; // < 64, 16 , 8>
using Shape = Shape_;
/// Data type of multiplicand A
using ElementA = ElementA_;
@ -156,8 +159,223 @@ public:
MmaDepthwiseSimt():Base() {}
};
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Shape of filter shape per threadblock - concept: gemm::GemmShape<Depth, Height, Width>
typename FilterShape_,
/// Shape of the output tile computed by thread- concept: conv::TensorNHWCShape<>
typename ThreadOutputShape_,
/// Shape of the output tile computed by threadblock - concept: conv::TensorNHWCShape<>
typename ThreadBlockOutputShape_,
/// Data type of A elements
typename ElementA_,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA_,
/// Data type of B elements
typename ElementB_,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB_,
/// Element type of C matrix
typename ElementC_,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC_,
/// Shape of the warp in units of thread (concept: MmaSimtPolicy)
typename Policy_,
/// Iterator algo type
conv::IteratorAlgorithm IteratorAlgorithm_ = IteratorAlgorithm::kAnalytic,
/// Stride ( MatrixShape<Height, Width> )
typename StrideShape_ = cutlass::MatrixShape<-1, -1>,
/// Dilation ( MatrixShape<Height, Width> )
typename DilationShape_ = cutlass::MatrixShape<-1, -1>,
/// Activation Shape loaded by threadblock
typename ActivationShape_ = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>,
/// Number of partitions along K dimension
int PartitionsK = 1,
/// Complex transformation on operand A
ComplexTransform TransformA = ComplexTransform::kNone,
/// Complex transformation on operand B
ComplexTransform TransformB = ComplexTransform::kNone,
/// Used for partial specialization
typename Enable = bool>
class MmaDepthwiseDirectConvSimt {
public:
/// Shape of warp-level matrix operation (concept: GemmShape)
using Shape = Shape_;
/// Shape of filter shape per threadblock - concept: gemm::GemmShape<Depth, Height, Width>
using FilterShape = FilterShape_;
/// Shape of the output tile computed by thread- concept: conv::TensorNHWCShape<>
using ThreadOutputShape = ThreadOutputShape_;
/// Shape of the output tile computed by threadblock - concept: conv::TensorNHWCShape<>
using ThreadBlockOutputShape = ThreadBlockOutputShape_;
/// Data type of multiplicand A
using ElementA = ElementA_;
/// Layout of multiplicand A
using LayoutA = LayoutA_;
/// Data type of multiplicand B
using ElementB = ElementB_;
/// Layout of multiplicand B
using LayoutB = LayoutB_;
/// Data type of accumulator matrix C
using ElementC = ElementC_;
/// Layout of accumulator matrix C
using LayoutC = LayoutC_;
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
using Policy = Policy_;
/// Iterator algo type
static conv::IteratorAlgorithm const IteratorAlgorithm = IteratorAlgorithm_;
/// Stride ( MatrixShape<Height, Width> )
using StrideShape = StrideShape_;
/// Dilation ( MatrixShape<Height, Width> )
using DilationShape = DilationShape_;
/// Activation Shape loaded by threadblock
using ActivationShape = ActivationShape_;
/// Indicates class of matrix operator
using OperatorClass = arch::OpClassSimt;
/// Hard-coded for now
using ArchTag = arch::Sm50;
/// Complex transform on A operand
static ComplexTransform const kTransformA = TransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = TransformB;
static constexpr bool use_dp4a = (platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA>::value ||
platform::is_same< layout::RowMajorInterleaved<4>, LayoutA >::value) &&
platform::is_same< ElementA, int8_t >::value &&
platform::is_same< ElementB, int8_t >::value;
using dp4a_type = typename platform::conditional< use_dp4a , int8_t, bool >::type;
/// Thread-level matrix multiply accumulate operator
using ThreadMma = cutlass::conv::thread::DepthwiseDirectConvElementwiseInnerProduct<
cutlass::gemm::GemmShape<
Shape::kM / Policy::WarpShape::kRow, // number of output pixels processed per thread
Shape::kN / Policy::WarpShape::kColumn, // number of channels processed per thread
1>,
ElementA,
ElementB,
ElementC,
arch::OpMultiplyAdd,
dp4a_type
>;
/// Underlying matrix multiply operator (concept: arch::Mma)
using ArchMmaOperator = typename ThreadMma::ArchMmaOperator;
/// Indicates math operator
using MathOperator = typename ArchMmaOperator::Operator;
/// Shape of the underlying instruction
using InstructionShape = cutlass::gemm::GemmShape<1,1,use_dp4a ? 4 : 1>;
public:
/// Iterates over the A operand in memory
using IteratorA = cutlass::conv::warp::DepthwiseDirect2dConvSimtTileIterator<
MatrixShape<Shape::kM, Shape::kN>, // <output tile=(P*Q), output channels> per warp
FilterShape,
ThreadOutputShape,
ThreadBlockOutputShape,
cutlass::gemm::Operand::kA,
ElementA,
Policy,
IteratorAlgorithm,
StrideShape,
DilationShape,
ActivationShape,
PartitionsK,
Shape::kK
>;
/// Storage for A tile
using FragmentA = typename IteratorA::Fragment;
/// Storage for transformed A tile
using TransformedFragmentA = FragmentA;
/// Iterates over the B operand in memory
using IteratorB = cutlass::gemm::warp::MmaSimtTileIterator<
MatrixShape<1, Shape::kN>,
cutlass::gemm::Operand::kB,
ElementB,
LayoutB,
Policy,
PartitionsK,
Shape::kK
>;
/// Storage for B tile
using FragmentB = typename IteratorB::Fragment;
/// Storage for transformed B tile
using TransformedFragmentB = FragmentB;
/// Iterates over the C operand in memory
using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator<
MatrixShape<Shape::kM, Shape::kN>,
cutlass::gemm::Operand::kC,
ElementC,
LayoutC,
Policy
>;
/// Storage for C tile
using FragmentC = typename ThreadMma::FragmentC;
public:
//
// Methods
//
/// Ctor
CUTLASS_DEVICE
MmaDepthwiseDirectConvSimt() {}
/// Performs a warp-level matrix multiply-accumulate operation
CUTLASS_DEVICE
void operator()(
FragmentC &d,
FragmentA a,
FragmentB b,
FragmentC const &c, int group_idx = 0) const {
ThreadMma mma;
mma(d, a, b, c);
}
/// Transform the mma operands to the required types
CUTLASS_DEVICE
void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B,
FragmentA const &A, FragmentB const &B) const {
//TODO: Implement this
dst_A = A;
dst_B = B;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace conv
} // namespace cutlass
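A sketch of how a mainloop might drive this warp-level operator. The helper name and loop structure are assumptions for illustration, not part of the diff:

// Illustrative only: Mma is assumed to be a fully instantiated MmaDepthwiseDirectConvSimt.
template <typename Mma>
CUTLASS_DEVICE void depthwise_warp_mma_step(
    typename Mma::IteratorA &iter_A,   // activations staged in shared memory
    typename Mma::IteratorB &iter_B,   // filter taps
    typename Mma::FragmentC &accum) {

  typename Mma::FragmentA frag_A;
  typename Mma::FragmentB frag_B;

  iter_A.load(frag_A);
  iter_B.load(frag_B);
  ++iter_A;   // advance to the next filter tap (S-major, then R)
  ++iter_B;

  Mma mma;
  mma(accum, frag_A, frag_B, accum);   // per-channel elementwise multiply-accumulate
}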

View File

@ -40,6 +40,8 @@
#include "cutlass/tensor_ref.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/layout/matrix.h"
@ -250,6 +252,611 @@ private:
///////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Size of the matrix to load (concept: MatrixShape)
typename Shape_,
/// Size of filter (concept: gemm::GemmShape<Depth, Height, Width>)
typename FilterShape_,
/// Size of the matrix to load (concept: MatrixShape)
typename ThreadOutputShape_,
/// Size of the matrix to load (concept: MatrixShape)
typename ThreadBlockOutputShape_,
/// Operand identity
cutlass::gemm::Operand Operand,
/// Data type of A elements
typename Element_,
/// Shape of the warp in units of thread (concept: MmaSimtPolicy)
typename Policy_,
/// Iterator algo type
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
/// Stride ( MatrixShape<Height, Width> )
typename StrideShape = cutlass::MatrixShape<-1, -1>,
/// Dilation ( MatrixShape<Height, Width> )
typename DilationShape = cutlass::MatrixShape<-1, -1>,
/// Activation Shape loaded by threadblock
typename ActivationShape = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>,
/// Number of partitions along K dimension - used in sliced-K
int PartitionsK = 1,
/// Group Size along kPartition - used in sliced-K
int PartitionGroupSize = 1>
class DepthwiseDirect2dConvSimtTileIterator;
/// Specialization for A operands of row-major layouts
///
/// Concept: MutableRandomAccessContiguousTileIteratorConcept
///
template <
/// Size of the matrix to load (concept: MatrixShape)
typename Shape_,
/// Size of filter (concept: gemm::GemmShape<Depth, Height, Width>)
typename FilterShape_,
/// Size of the matrix to load (concept: TensorNHWC)
typename ThreadOutputShape_,
/// Size of the matrix to load (concept: TensorNHWC)
typename ThreadBlockOutputShape_,
/// Data type of A elements
typename Element_,
/// Shape of the warp in units of thread (concept: MmaSimtPolicy)
typename Policy_,
/// Iterator algo type
conv::IteratorAlgorithm IteratorAlgorithm,
/// Stride ( MatrixShape<Height, Width> )
typename StrideShape,
/// Dilation ( MatrixShape<Height, Width> )
typename DilationShape,
/// Activation Shape loaded by threadblock
typename ActivationShape,
/// Number of partitions along K dimension - used in sliced-K
int PartitionsK,
/// Group Size along kPartition - used in sliced-K
int PartitionGroupSize>
class DepthwiseDirect2dConvSimtTileIterator<Shape_,
FilterShape_,
ThreadOutputShape_,
ThreadBlockOutputShape_,
cutlass::gemm::Operand::kA,
Element_,
Policy_,
IteratorAlgorithm,
StrideShape,
DilationShape,
ActivationShape,
PartitionsK,
PartitionGroupSize> {
public:
/// Shape of tile to load (concept: MatrixShape)
using Shape = Shape_;
/// Shape of filter (concept: gemm::GemmShape<Depth, Height, Width>)
using FilterShape = FilterShape_;
/// Shape of tile to load (concept: TensorNHWC)
using ThreadOutputShape = ThreadOutputShape_;
/// Shape of tile to load (concept: TensorNHWC)
using ThreadBlockOutputShape = ThreadBlockOutputShape_;
/// Operand tag
static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kA;
/// Element type
using Element = Element_;
/// Layout of policy
using Layout = layout::RowMajor;
/// Decomposition of elements among threads
using Policy = Policy_;
/// TensorRef type for loading element from a tensor
using TensorRef = TensorRef<Element, Layout>;
/// Index type
using Index = typename TensorRef::Index;
/// Long Index type
using LongIndex = typename TensorRef::LongIndex;
/// Coordinate for an element in the tensor
using TensorCoord = typename TensorRef::TensorCoord;
//
// Derived quantities
//
static_assert(!(Shape::kRow % Policy::WarpShape::kRow),
"The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension.");
static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero.");
static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");
// Thread-level shape of a fragment
using ThreadShape = MatrixShape<
ThreadOutputShape::kNHW, // Output tile shape computed by current threads
ThreadOutputShape::kC
>;
static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN),
"Thread-level GEMM must be divisible by Policy::LaneMmaShape.");
/// Number of individual loads
using Iterations = MatrixShape<
ThreadShape::kRow,
ThreadShape::kColumn / Policy::LaneMmaShape::kN
>;
using ThreadTileCount = MatrixShape<
ThreadBlockOutputShape::kH / ThreadOutputShape::kH,
ThreadBlockOutputShape::kW / ThreadOutputShape::kW
>;
/// Fragment object holding a thread's part of a tile
using Fragment = Array<Element, ThreadShape::kCount>;
protected:
/// Internal reference
cutlass::TensorRef<Array<Element, Policy::LaneMmaShape::kN>, layout::RowMajor> ref_;
int activation_offset[ThreadOutputShape::kH][ThreadOutputShape::kW][Iterations::kColumn];
int iterator_r_;
int iterator_s_;
int iterator_offset_;
int inc_next_s_ ;
int inc_next_r_ ;
MatrixCoord lane_offset_;
public:
/// Default ctor constructs null iterator
CUTLASS_HOST_DEVICE
DepthwiseDirect2dConvSimtTileIterator() { }
/// Constructor from TensorRef
CUTLASS_HOST_DEVICE
DepthwiseDirect2dConvSimtTileIterator(
TensorRef ref,
int lane_id
) {
// compute offset based on thread ID and lane layout
typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
// Set channel offset
lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN);
ref.add_coord_offset(lane_offset_);
ref_.reset(reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(ref.data()),
ref.stride(0) / Policy::LaneMmaShape::kN);
iterator_r_ = 0;
iterator_s_ = 0;
iterator_offset_ = 0;
}
/// Adds a pointer offset to internal pointer(s) to advance through memory
CUTLASS_HOST_DEVICE
DepthwiseDirect2dConvSimtTileIterator &add_pointer_offset(LongIndex offset) {
ref_.add_pointer_offset(offset);
return *this;
}
/// Loads a fragment from memory at the location pointed to by the iterator.
template<typename Params>
CUTLASS_HOST_DEVICE
void setup_initial_status(Params const& params) {
inc_next_s_ = params.inc_next[0];
inc_next_r_ = params.inc_next[1];
// Get base HW offset of current threads
int threadgroup = threadIdx.x / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC);
int base_p_ =
(threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH;
int base_q_ =
(threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW;
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < ThreadOutputShape::kH; ++p) {
CUTLASS_PRAGMA_UNROLL
for (int q = 0; q < ThreadOutputShape::kW; ++q) {
CUTLASS_PRAGMA_UNROLL
for (int col = 0; col < Iterations::kColumn; ++col) {
int base_w = (base_q_ + q) * params.stride[0];
int base_h = (base_p_ + p) * params.stride[1];
int offset = base_h * params.activation_tile_w + base_w;
activation_offset[p][q][col] = offset;
}
}
}
}
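// Example of the mapping above, with assumed shapes (illustration only):
// ThreadBlockOutputShape = <1, 4, 4, 64>, ThreadOutputShape = <1, 2, 2, 32>
//   -> two threads share each output pixel along C, ThreadTileCount = <2, 2>,
//      and threadgroup 3 starts at (base_p_, base_q_) = (2, 2).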
/// Advances an iterator along logical dimensions of matrix in units of whole tiles
CUTLASS_HOST_DEVICE
DepthwiseDirect2dConvSimtTileIterator &add_tile_offset(TensorCoord const &coord) {
// Set warp row and col start
lane_offset_ = MatrixCoord({lane_offset_.row() + coord.row() * Shape::kRow, lane_offset_.column()});
return *this;
}
/// Advances an iterator along logical dimensions of matrix in units of whole tiles
CUTLASS_HOST_DEVICE
void advance(int32_t pointer_offset) {
ref_.reset(ref_.data() + pointer_offset / sizeof(Element) / Policy::LaneMmaShape::kN);
iterator_s_ = 0;
iterator_r_ = 0;
iterator_offset_ = 0;
}
/// Advances the iterator along the advance dimension
CUTLASS_HOST_DEVICE
DepthwiseDirect2dConvSimtTileIterator &operator++() {
++iterator_s_;
if (iterator_s_ < FilterShape::kColumn) {
iterator_offset_ += inc_next_s_;
return *this;
}
iterator_s_ = 0;
++iterator_r_;
if (iterator_r_ < FilterShape::kRow) {
iterator_offset_ += inc_next_r_;
return *this;
}
iterator_r_ = 0;
iterator_offset_ = 0;
return *this;
}
/// Advances the iterator along the advance dimension
CUTLASS_HOST_DEVICE
DepthwiseDirect2dConvSimtTileIterator & operator--() {
// Do nothing
return *this;
}
/// Loads a fragment from memory at the location pointed to by the iterator. (vector loads)
CUTLASS_HOST_DEVICE
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {
Array<Element, Policy::LaneMmaShape::kN> *dst_ptr =
reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < ThreadOutputShape::kH; ++p) {
CUTLASS_PRAGMA_UNROLL
for (int q = 0; q < ThreadOutputShape::kW; ++q) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Iterations::kColumn; ++n) {
void const *ptr = ref_.data() +
ref_.offset({activation_offset[p][q][n] + (iterator_offset_),
n * Policy::WarpShape::kColumn}) +
pointer_offset / Policy::LaneMmaShape::kN;
arch::shared_load(dst_ptr[n + q + p * ThreadOutputShape::kW], ptr);
}
}
}
}
/// Loads a fragment from memory at the location pointed to by the iterator.
CUTLASS_HOST_DEVICE
void load(Fragment &frag) const {
load_with_pointer_offset(frag, 0);
}
/// Stores a fragment to memory at the location pointed to by the iterator
CUTLASS_HOST_DEVICE
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {
// Do nothing at present.
}
/// Stores a fragment to memory at the location pointed to by the iterator
CUTLASS_HOST_DEVICE
void store(Fragment const &frag, Index pointer_offset) const {
store_with_pointer_offset(frag, 0);
}
CUTLASS_DEVICE
void set_kgroup_index(int k_group) {
// no operation here
}
};
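// Traversal produced by repeated operator++ above, for an assumed 3x3 filter:
//   (r, s): (0,0) (0,1) (0,2) (1,0) (1,1) (1,2) (2,0) (2,1) (2,2) -> wraps to (0,0)
// iterator_offset_ grows by inc_next_s_ per s-step and by inc_next_r_ per row change,
// then resets to zero once every filter tap has been visited.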
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Specialization for A operands of row-major layouts
///
/// Concept: MutableRandomAccessContiguousTileIteratorConcept
///
template <
/// Size of the matrix to load (concept: MatrixShape)
typename Shape_,
/// Size of filter (concept: gemm::GemmShape<Depth, Height, Width>)
typename FilterShape_,
/// Size of the matrix to load (concept: TensorNHWC)
typename ThreadOutputShape_,
/// Size of the matrix to load (concept: TensorNHWC)
typename ThreadBlockOutputShape_,
/// Data type of A elements
typename Element_,
/// Shape of the warp in units of thread (concept: MmaSimtPolicy)
typename Policy_,
/// Stride ( MatrixShape<Height, Width> )
typename StrideShape_,
/// Dilation ( MatrixShape<Height, Width> )
typename DilationShape_,
/// Activation Shape loaded by threadblock
typename ActivationShape_,
/// Number of partitions along K dimension - used in sliced-K
int PartitionsK,
/// Group Size along kPartition - used in sliced-K
int PartitionGroupSize>
class DepthwiseDirect2dConvSimtTileIterator<Shape_,
FilterShape_,
ThreadOutputShape_,
ThreadBlockOutputShape_,
cutlass::gemm::Operand::kA,
Element_,
Policy_,
IteratorAlgorithm::kFixedStrideDilation,
StrideShape_,
DilationShape_,
ActivationShape_,
PartitionsK,
PartitionGroupSize> {
public:
/// Shape of tile to load (concept: MatrixShape)
using Shape = Shape_;
/// Shape of filter (concept: gemm::GemmShape<Depth, Height, Width>)
using FilterShape = FilterShape_;
/// Shape of tile to load (concept: TensorNHWC)
using ThreadOutputShape = ThreadOutputShape_;
/// Shape of tile to load (concept: TensorNHWC)
using ThreadBlockOutputShape = ThreadBlockOutputShape_;
/// Stride ( MatrixShape<Height, Width> )
using StrideShape = StrideShape_;
/// Dilation ( MatrixShape<Height, Width> )
using DilationShape = DilationShape_;
/// Activation Shape loaded by threadblock
using ActivationShape = ActivationShape_;
/// Operand tag
static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kA;
/// Element type
using Element = Element_;
/// Layout of policy
using Layout = layout::RowMajor;
/// Decomposition of elements among threads
using Policy = Policy_;
/// TensorRef type for loading element from a tensor
using TensorRef = TensorRef<Element, Layout>;
/// Index type
using Index = typename TensorRef::Index;
/// Long Index type
using LongIndex = typename TensorRef::LongIndex;
/// Coordinate for an element in the tensor
using TensorCoord = typename TensorRef::TensorCoord;
//
// Derived quantities
//
static_assert(!(Shape::kRow % Policy::WarpShape::kRow),
"The warp-level GEMM M size must be divisible by the number of threads arranged "
"along the M dimension.");
static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero.");
static_assert(Shape::kRow / Policy::WarpShape::kRow > 0,
"Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");
// Activations loaded by threadblock
static int const ThreadActivationShapeH = (ThreadOutputShape::kH - 1) * StrideShape::kRow +
(FilterShape::kRow - 1) * DilationShape::kRow + 1;
static int const ThreadActivationShapeW = (ThreadOutputShape::kW - 1) * StrideShape::kColumn +
(FilterShape::kColumn - 1) * DilationShape::kColumn + 1;
using ThreadActivationShape = cutlass::conv::
TensorNHWCShape<1, ThreadActivationShapeH, ThreadActivationShapeW, ThreadOutputShape::kC>;
// Thread-level shape of a fragment
using ThreadShape =
MatrixShape<ThreadOutputShape::kNHW,
ThreadOutputShape::kC>;
static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN),
"Thread-level GEMM must be divisible by Policy::LaneMmaShape.");
/// Number of individual loads
using Iterations =
MatrixShape<ThreadShape::kRow, ThreadShape::kColumn / Policy::LaneMmaShape::kN>;
using ThreadTileCount = MatrixShape<ThreadBlockOutputShape::kH / ThreadOutputShape::kH,
ThreadBlockOutputShape::kW / ThreadOutputShape::kW>;
/// Fragment object holding a thread's part of a tile
using Fragment = Array<Element, ThreadShape::kCount>;
protected:
/// Internal reference
cutlass::TensorRef<Array<Element, Policy::LaneMmaShape::kN>, layout::RowMajor> ref_;
Array<Element, Policy::LaneMmaShape::kN>
activation[ThreadActivationShape::kH][ThreadActivationShape::kW][Iterations::kColumn];
int iterator_r_;
int iterator_s_;
MatrixCoord lane_offset_;
public:
/// Default ctor constructs null iterator
CUTLASS_HOST_DEVICE
DepthwiseDirect2dConvSimtTileIterator() {}
/// Constructor from TensorRef
CUTLASS_HOST_DEVICE
DepthwiseDirect2dConvSimtTileIterator(TensorRef ref, int lane_id) {
// compute offset based on thread ID and lane layout
typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
// Set channel offset
lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN);
ref.add_coord_offset(lane_offset_);
ref_.reset(reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(ref.data()),
ref.stride(0) / Policy::LaneMmaShape::kN);
iterator_r_ = 0;
iterator_s_ = 0;
}
/// Adds a pointer offset to internal pointer(s) to advance through memory
CUTLASS_HOST_DEVICE
DepthwiseDirect2dConvSimtTileIterator &add_pointer_offset(LongIndex offset) {
ref_.add_pointer_offset(offset);
return *this;
}
/// Loads a fragment from memory at the location pointed to by the iterator.
template <typename Params>
CUTLASS_HOST_DEVICE void setup_initial_status(
Params const &params) {
// Get base HW offset of current threads
int threadgroup = threadIdx.x / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC);
int base_h =
(threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH * StrideShape::kRow;
int base_w =
(threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW * StrideShape::kColumn;
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < ThreadActivationShape::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < ThreadActivationShape::kW; ++w) {
CUTLASS_PRAGMA_UNROLL
for (int col = 0; col < Iterations::kColumn; ++col) {
int offset = (base_h + h) * ActivationShape::kW + (base_w + w);
void const *ptr = ref_.data() + ref_.offset({offset, col * Policy::WarpShape::kColumn});
arch::shared_load(activation[h][w][col], ptr);
}
}
}
}
/// Advances an iterator along logical dimensions of matrix in units of whole tiles
CUTLASS_HOST_DEVICE
DepthwiseDirect2dConvSimtTileIterator &add_tile_offset(TensorCoord const &coord) {
// Set warp row and col start
lane_offset_ =
MatrixCoord({lane_offset_.row() + coord.row() * Shape::kRow, lane_offset_.column()});
return *this;
}
/// Advances an iterator along logical dimensions of matrix in units of whole tiles
CUTLASS_HOST_DEVICE
void advance(int32_t pointer_offset) {
ref_.reset(ref_.data() + pointer_offset / sizeof(Element) / Policy::LaneMmaShape::kN);
iterator_s_ = 0;
iterator_r_ = 0;
}
/// Advances the iterator along the advance dimension
CUTLASS_HOST_DEVICE
DepthwiseDirect2dConvSimtTileIterator &operator++() {
++iterator_s_;
if (iterator_s_ < FilterShape::kColumn) {
return *this;
}
iterator_s_ = 0;
++iterator_r_;
if (iterator_r_ < FilterShape::kRow) {
return *this;
}
iterator_r_ = 0;
return *this;
}
/// Advances the iterator along the advance dimension
CUTLASS_HOST_DEVICE
DepthwiseDirect2dConvSimtTileIterator &operator--() {
// Do nothing
return *this;
}
/// Loads a fragment from memory at the location pointed to by the iterator. (vector loads)
CUTLASS_HOST_DEVICE
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {
Array<Element, Policy::LaneMmaShape::kN> *dst_ptr =
reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < ThreadOutputShape::kH; ++p) {
CUTLASS_PRAGMA_UNROLL
for (int q = 0; q < ThreadOutputShape::kW; ++q) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Iterations::kColumn; ++n) {
const int h = p * StrideShape::kRow + iterator_r_ * DilationShape::kRow;
const int w = q * StrideShape::kColumn + iterator_s_ * DilationShape::kColumn;
dst_ptr[n + q + p * ThreadOutputShape::kW] = activation[h][w][n];
}
}
}
}
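// Note: unlike the generic specialization above, this load touches no shared memory;
// it gathers from the `activation` register array filled once in setup_initial_status.
// For an assumed 3x3 filter with unit stride and dilation, output pixel (p, q) at
// filter tap (r, s) reads activation[p + r][q + s][n], so all nine taps reuse the
// same preloaded patch.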
/// Loads a fragment from memory at the location pointed to by the iterator.
CUTLASS_HOST_DEVICE
void load(Fragment &frag) const { load_with_pointer_offset(frag, 0); }
/// Stores a fragment to memory at the location pointed to by the iterator
CUTLASS_HOST_DEVICE
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {
// Do nothing at present.
}
/// Stores a fragment to memory at the location pointed to by the iterator
CUTLASS_HOST_DEVICE
void store(Fragment const &frag, Index pointer_offset) const {
store_with_pointer_offset(frag, 0);
}
CUTLASS_DEVICE
void set_kgroup_index(int k_group) {
// no operation here
}
};
} // namespace warp
} // namespace gemm
} // namespace conv
} // namespace cutlass
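The ThreadActivationShape computed by the fixed stride/dilation specialization is the usual receptive-field formula. With assumed shapes (illustrative, not from the diff):

// ThreadOutputShape = <1, 2, 2, C>, StrideShape = <1, 1>,
// DilationShape = <1, 1>, FilterShape = <1, 3, 3>:
//   ThreadActivationShapeH = (2 - 1) * 1 + (3 - 1) * 1 + 1 = 4
//   ThreadActivationShapeW = (2 - 1) * 1 + (3 - 1) * 1 + 1 = 4
// Each thread preloads a 4x4 activation patch into registers once and reuses it
// for all nine filter taps, rather than reloading from shared memory per tap.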

View File

@ -100,12 +100,21 @@ public:
}
}
/// Constructs from some other Coord
template <int R, typename I, typename L>
CUTLASS_HOST_DEVICE
Coord(Coord<R, I, L> other) {
for (int i = 0; i < kRank; ++i) {
idx[i] = other[i];
}
}
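// Usage sketch for the converting constructor (values are illustrative):
//   Coord<3, int64_t> src;  src[0] = 2; src[1] = 3; src[2] = 4;
//   Coord<3, int> dst(src); // element-wise copy, converting the index type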
/// Returns a slice of the Coord which may be larger or smaller in rank
/// than this.
template <int Slice>
CUTLASS_HOST_DEVICE
Coord<Slice> slice(int start = 0, Index identity = 0) const {
Coord<Slice> result;
Coord<Slice, Index, LongIndex> slice(int start = 0, Index identity = 0) const {
Coord<Slice, Index, LongIndex> result;
for (int i = 0; i < Slice; ++i) {
if (i + start < kRank) {
result[i] = idx[i + start];

View File

@ -59,7 +59,9 @@ inline std::ostream &operator<<(std::ostream &out, dim3 d) {
/// Output operator for CUDA built-in error type
inline std::ostream &operator<<(std::ostream &out, cudaError_t error) {
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
return out << cudaGetErrorString(error);
#endif
}
///////////////////////////////////////////////////////////////////////////////////////////////////
@ -252,8 +254,9 @@ namespace conv {
inline
std::ostream& operator<<(std::ostream& out, Conv2dProblemSize const& problem) {
out << "NHWC: (" << problem.N << ", " << problem.H << ", " << problem.W << ", " << problem.C << ")" << std::endl
<< "KRSC: (" << problem.K << ", " << problem.R << ", " << problem.S << ", " << problem.C << ")" << std::endl
<< "KRSC: (" << problem.K << ", " << problem.R << ", " << problem.S << ", " << problem.C / problem.groups << ")" << std::endl
<< "NPQK: (" << problem.N << ", " << problem.P << ", " << problem.Q << ", " << problem.K << ")" << std::endl
<< "groups: (" << problem.groups << ")" << std::endl
<< "Pad_h, Pad_w: (" << problem.pad_h << ", " << problem.pad_w << ")" << std::endl
<< "Stride_h, Stride_w: (" << problem.stride_h << ", " << problem.stride_w << ")" << std::endl
<< "Dilation_h, Dilation_w: (" << problem.dilation_h << ", " << problem.dilation_w << ")" << std::endl

View File

@ -57,6 +57,23 @@ void Kernel(typename Operator::Params params) {
op(params, *shared_storage);
}
/// Generic CUTLASS kernel template.
template <typename Operator>
__global__
void Kernel2(typename Operator::Params params) {
// Dynamic shared memory base pointer
extern __shared__ int SharedStorageBase[];
// Declare pointer to dynamic shared memory.
typename Operator::SharedStorage *shared_storage =
reinterpret_cast<typename Operator::SharedStorage *>(SharedStorageBase);
Operator::invoke(params, *shared_storage);
}
////////////////////////////////////////////////////////////////////////////////
} /// namespace cutlass
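A minimal launch sketch for Kernel2; the grid, block, and stream values are placeholders, and in practice the device-level adaptor derives them from Params:

// Illustrative host-side launch (not part of the diff); Operator is the
// instantiated kernel-level operator type.
int smem_size = int(sizeof(typename Operator::SharedStorage));
if (smem_size >= (48 << 10)) {
  // Opt in to large dynamic shared memory allocations when required.
  cudaFuncSetAttribute(cutlass::Kernel2<Operator>,
                       cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
cutlass::Kernel2<Operator><<<grid, block, smem_size, stream>>>(params);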

View File

@ -104,15 +104,15 @@ struct Identity {
template <typename T, int N>
struct Identity<Array<T, N> > {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs) const {
return rhs;
Array<T, N> operator()(Array<T, N> const &value) const {
return value;
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
return this->operator()(rhs);
Array<T, N> operator()(Array<T, N> const &value, Params const &params_) const {
return this->operator()(value);
}
};
@ -183,7 +183,7 @@ struct LeakyReLU {
Params():
LinearCombinationGenericParams<T>(),
leaky_alpha(T(1)) {}
CUTLASS_HOST_DEVICE
Params(
T alpha,
@ -228,21 +228,21 @@ struct LeakyReLU<Array<T, N> > {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs, T const & alpha_recip) const {
Array<T, N> operator()(Array<T, N> const &value, T const & alpha_recip) const {
Array<T, N> y;
LeakyReLU<T> leaky_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < int(rhs.size()); ++i) {
y[i] = leaky_op(rhs[i], alpha_recip);
for (int i = 0; i < int(value.size()); ++i) {
y[i] = leaky_op(value[i], alpha_recip);
}
return y;
}
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
return this->operator()(rhs, params_.leaky_alpha);
Array<T, N> operator()(Array<T, N> const &value, Params const &params_) const {
return this->operator()(value, params_.leaky_alpha);
}
};
@ -265,13 +265,13 @@ struct Tanh {
template <typename T, int N>
struct Tanh<Array<T, N> > {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs) const {
Array<T, N> operator()(Array<T, N> const &value) const {
Array<T, N> y;
Tanh<T> tanh_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
y[i] = tanh_op(rhs[i]);
y[i] = tanh_op(value[i]);
}
return y;
@ -280,8 +280,8 @@ struct Tanh<Array<T, N> > {
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
return this->operator()(rhs);
Array<T, N> operator()(Array<T, N> const &value, Params const &params_) const {
return this->operator()(value);
}
};
@ -299,8 +299,8 @@ struct Tanh<Array<half_t, N>> {
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
return this->operator()(rhs);
Array<T, N> operator()(Array<T, N> const &value, Params const &params_) const {
return this->operator()(value);
}
};
@ -323,13 +323,13 @@ struct Sigmoid {
template <typename T, int N>
struct Sigmoid<Array<T, N> > {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs) const {
Array<T, N> operator()(Array<T, N> const &value) const {
Array<T, N> y;
Sigmoid<T> sigmoid_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
y[i] = sigmoid_op(rhs[i]);
y[i] = sigmoid_op(value[i]);
}
return y;
@ -338,8 +338,8 @@ struct Sigmoid<Array<T, N> > {
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
return this->operator()(rhs);
Array<T, N> operator()(Array<T, N> const &value, Params const &params_) const {
return this->operator()(value);
}
};
@ -398,17 +398,17 @@ struct SiLu {
template <typename T, int N>
struct SiLu<Array<T, N>> {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs) const {
Array<T, N> operator()(Array<T, N> const &value) const {
Sigmoid<Array<T, N>> sigmoid_op;
multiplies<Array<T, N>> mul;
return mul(rhs, sigmoid_op(rhs));
return mul(value, sigmoid_op(value));
}
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
return this->operator()(rhs);
Array<T, N> operator()(Array<T, N> const &value, Params const &params_) const {
return this->operator()(value);
}
};
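// Scalar form of the array operator above: SiLu(x) = x * sigmoid(x) = x / (1 + exp(-x)),
// so SiLu(0) = 0 and SiLu(x) approaches x for large positive x.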
@ -458,13 +458,13 @@ struct HardSwish<float> {
template <typename T, int N>
struct HardSwish<Array<T, N> > {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs) const {
Array<T, N> operator()(Array<T, N> const &value) const {
Array<T, N> y;
HardSwish<T> hardswish_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
y[i] = hardswish_op(rhs[i]);
y[i] = hardswish_op(value[i]);
}
return y;
@ -483,13 +483,13 @@ struct HardSwish<Array<half_t, N> > {
using T = half_t;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs) const {
Array<T, N> operator()(Array<T, N> const &value) const {
minimum<Array<T, N> > mn;
maximum<Array<T, N> > mx;
multiplies<Array<T, N> > mul;
plus<Array<T, N> > add;
return mul(mul(mn(mx(add(rhs, T(3)), T(0)), T(6)), rhs), T(0.16666667f));
return mul(mul(mn(mx(add(value, T(3)), T(0)), T(6)), value), T(0.16666667f));
}
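// Scalar form of the half-precision path above:
//   HardSwish(x) = x * min(max(x + 3, 0), 6) / 6
// e.g. HardSwish(1) = 1 * min(max(4, 0), 6) / 6 = 4 / 6 ~= 0.667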
using Params = LinearCombinationGenericParams<T>;
@ -561,13 +561,13 @@ struct GELU<double> {
template <typename T, int N>
struct GELU<Array<T, N> > {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs) const {
Array<T, N> operator()(Array<T, N> const &value) const {
Array<T, N> y;
GELU<T> gelu_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
y[i] = gelu_op(rhs[i]);
y[i] = gelu_op(value[i]);
}
return y;
@ -576,8 +576,8 @@ struct GELU<Array<T, N> > {
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
return this->operator()(rhs);
Array<T, N> operator()(Array<T, N> const &value, Params const &params_) const {
return this->operator()(value);
}
};
@ -601,7 +601,6 @@ struct GELU_taylor {
T operator()(T const &scalar, Params const &params_) const {
return this->operator()(scalar);
}
};
template <int N>
@ -632,8 +631,8 @@ struct GELU_taylor<Array<half_t, N> > {
using Params = LinearCombinationGenericParams<half_t>;
CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(Array<half_t, N> const &rhs, Params const &params_) const {
return this->operator()(rhs);
Array<half_t, N> operator()(Array<half_t, N> const &value, Params const &params_) const {
return this->operator()(value);
}
};
@ -641,13 +640,13 @@ template <typename T, int N>
struct GELU_taylor<Array<T, N> > {
static const bool kIsHeavy=true;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs) const {
Array<T, N> operator()(Array<T, N> const &value) const {
Array<T, N> y;
GELU_taylor<T> gelu_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
y[i] = gelu_op(rhs[i]);
y[i] = gelu_op(value[i]);
}
return y;
@ -656,8 +655,8 @@ struct GELU_taylor<Array<T, N> > {
using Params = LinearCombinationGenericParams<T>;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
return this->operator()(rhs);
Array<T, N> operator()(Array<T, N> const &value, Params const &params_) const {
return this->operator()(value);
}
};

View File

@ -78,6 +78,9 @@ public:
using ElementwiseOp = ElementwiseOp_;
using BinaryOp = BinaryOp_;
// Indicates that this epilogue applies only one binary operation
static bool const kIsSingleSource = true;
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
using FragmentC = Array<ElementOutput, kElementsPerAccess>;

View File

@ -223,6 +223,9 @@ public:
using ElementwiseOp = ReLu<ElementCompute>;
using BinaryOp = plus<ElementCompute>;
// Indicates that this epilogue applies only one binary operation
static bool const kIsSingleSource = true;
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
using FragmentC = Array<ElementOutput, kElementsPerAccess>;
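// Kernel-level code can branch on this trait at compile time; a sketch (illustrative):
//   if constexpr (OutputOp::kIsSingleSource) { /* epilogue reads one source tensor */ }
//   else                                     { /* epilogue reads two source tensors */ }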

View File

@ -37,6 +37,7 @@
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/scale_type.h"
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -30,7 +30,7 @@
**************************************************************************************************/
/*! \file
\brief Epilogue functor specialized for residual blocks in deep neural network.
\brief Epilogue functor specialized for residual blocks in deep neural networks.
*/
#pragma once
@ -45,14 +45,24 @@ namespace cutlass {
namespace epilogue {
namespace thread {
// /// Models a residual block of the form: UnaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual))
namespace detail {
/// Dummy class used to designate that the second binary operator in the epilogue is unused
template <typename T>
class NoOp {};
}
/// Models a residual block of the form: UnaryOp(BinaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual1), residual2))
template <typename ElementOutput_, typename ElementAccumulator_,
typename ElementCompute_, typename ElementC_, int ElementsPerAccess,
template <typename T> class ActivationOp_,
template <typename T> class BinaryOp_,
template <typename T> class UnaryOp_>
template <typename T> class BinaryOp1_,
template <typename T> class UnaryOp_,
template <typename T> class BinaryOp2_ = detail::NoOp>
class LinearCombinationResidualBlock {
public:
static bool const kIsSingleSource = false;
using ElementOutput = ElementC_;
using ElementC = ElementC_;
@ -62,7 +72,130 @@ public:
static int const kCount = kElementsPerAccess;
using UnaryOp = UnaryOp_<Array<ElementCompute, kCount>>;
using BinaryOp = BinaryOp_<Array<ElementCompute, kCount>>;
using BinaryOp1 = BinaryOp1_<Array<ElementCompute, kCount>>;
using BinaryOp2 = BinaryOp2_<Array<ElementCompute, kCount>>;
using ActivationOp = ActivationOp_<Array<ElementCompute, kCount>>;
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
using FragmentC = Array<ElementC, kElementsPerAccess>;
using FragmentOutput = Array<ElementOutput, kElementsPerAccess>;
using ElementZ = ElementOutput_;
using ElementT = ElementZ;
using FragmentZ = Array<ElementZ, kElementsPerAccess>;
using FragmentT = Array<ElementT, kElementsPerAccess>;
static bool const kIsHeavy = true;
static bool const kStoreZ = true;
static bool const kStoreT = false;
/// Host-constructable parameters structure
struct Params {
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales residual input
ElementCompute const *alpha_ptr{nullptr}; ///< pointer to accumulator scalar - if not null, loads it from memory
ElementCompute const *beta_ptr{nullptr}; ///< pointer to residual scalar - if not null, loads it from memory
CUTLASS_HOST_DEVICE
Params() : alpha(ElementCompute(1)), beta(ElementCompute(1)) {}
CUTLASS_HOST_DEVICE
Params(ElementCompute alpha, ElementCompute beta)
: alpha(alpha), beta(beta) {}
CUTLASS_HOST_DEVICE
Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
: alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {}
};
private:
ElementCompute alpha_;
ElementCompute beta_;
bool skip_elementwise_;
public:
/// Constructor from Params
CUTLASS_HOST_DEVICE
LinearCombinationResidualBlock(Params const &params) {
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
skip_elementwise_ = false;
}
/// The "source" tensor corresponds to the residual input
CUTLASS_HOST_DEVICE
bool is_source_needed() const { return true; }
/// Functionally required for serial reduction in the epilogue
/// IMPORTANT: Split-k is supported only when ActivationOp is Identity.
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {
if (k_partition) {
beta_ = ElementCompute(1);
}
if (k_partition != k_partition_count - 1) {
skip_elementwise_ = true;
}
}
/// Applies the operation UnaryOp(BinaryOp(BinaryOp(ActivationOp(AB + bias), residual1), residual2))
CUTLASS_HOST_DEVICE
void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB,
FragmentC const &residual1, FragmentC const &residual2,
FragmentCompute const &bias) const {
UnaryOp unary_op;
BinaryOp1 binary_op1;
BinaryOp2 binary_op2;
ActivationOp activation;
FragmentCompute tmp_Accum =
NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
FragmentCompute tmp_residual1 =
NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual1);
FragmentCompute tmp_residual2 =
NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual2);
FragmentCompute z =
binary_op2(binary_op1(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual1), beta_ * tmp_residual2);
FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z);
NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> convert_z;
frag_Z = convert_z(result_Z);
}
/// Should never be called
CUTLASS_HOST_DEVICE
void operator()(FragmentOutput &, FragmentOutput &, FragmentAccumulator const &,
FragmentCompute const &) const {}
};
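// Per element, the operator above computes (symbols follow the Params fields):
//   Z = UnaryOp(BinaryOp2(BinaryOp1(ActivationOp(alpha * AB + bias), beta * residual1),
//               beta * residual2))
// For example, with ActivationOp = ReLU, BinaryOp1 = BinaryOp2 = plus, UnaryOp = Identity:
//   Z = ReLU(alpha * AB + bias) + beta * (residual1 + residual2)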
/// Models a residual block of the form: UnaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual))
template <typename ElementOutput_, typename ElementAccumulator_,
typename ElementCompute_, typename ElementC_, int ElementsPerAccess,
template <typename T> class ActivationOp_,
template <typename T> class BinaryOp1_,
template <typename T> class UnaryOp_>
class LinearCombinationResidualBlock<ElementOutput_, ElementAccumulator_,
ElementCompute_, ElementC_, ElementsPerAccess,
ActivationOp_, BinaryOp1_, UnaryOp_,
detail::NoOp> {
public:
static bool const kIsSingleSource = true;
using ElementOutput = ElementC_;
using ElementC = ElementC_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kElementsPerAccess = ElementsPerAccess;
static int const kCount = kElementsPerAccess;
using UnaryOp = UnaryOp_<Array<ElementCompute, kCount>>;
using BinaryOp = BinaryOp1_<Array<ElementCompute, kCount>>;
using ActivationOp = ActivationOp_<Array<ElementCompute, kCount>>;
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;

View File

@ -1,197 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue functor specialized for residual blocks in deep neural network.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
// /// Models a residual block of the form: UnaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual))
// or form UnaryOp(BinaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual1), residual2))
template <typename ElementOutput_, typename ElementAccumulator_,
typename ElementCompute_, typename ElementC_, int ElementsPerAccess,
template <typename T> class ActivationOp_,
template <typename T> class BinaryOp1_,
template <typename T> class UnaryOp_,
template <typename T> class BinaryOp2_=BinaryOp1_>
class LinearCombinationResidualBlockV2 {
public:
using ElementOutput = ElementC_;
using ElementC = ElementC_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kElementsPerAccess = ElementsPerAccess;
static int const kCount = kElementsPerAccess;
using UnaryOp = UnaryOp_<Array<ElementCompute, kCount>>;
using BinaryOp1 = BinaryOp1_<Array<ElementCompute, kCount>>;
using BinaryOp2 = BinaryOp2_<Array<ElementCompute, kCount>>;
using ActivationOp = ActivationOp_<Array<ElementCompute, kCount>>;
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
using FragmentC = Array<ElementC, kElementsPerAccess>;
using FragmentOutput = Array<ElementOutput, kElementsPerAccess>;
using ElementZ = ElementOutput_;
using ElementT = ElementZ;
using FragmentZ = Array<ElementZ, kElementsPerAccess>;
using FragmentT = Array<ElementT, kElementsPerAccess>;
static bool const kIsHeavy = true;
static bool const kStoreZ = true;
static bool const kStoreT = false;
/// Host-constructable parameters structure
struct Params {
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales residual input
ElementCompute const *alpha_ptr{nullptr}; ///< pointer to accumulator scalar - if not null, loads it from memory
ElementCompute const *beta_ptr{nullptr}; ///< pointer to residual scalar - if not null, loads it from memory
CUTLASS_HOST_DEVICE
Params() : alpha(ElementCompute(1)), beta(ElementCompute(1)) {}
CUTLASS_HOST_DEVICE
Params(ElementCompute alpha, ElementCompute beta)
: alpha(alpha), beta(beta) {}
CUTLASS_HOST_DEVICE
Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
: alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {}
};
private:
ElementCompute alpha_;
ElementCompute beta_;
bool skip_elementwise_;
public:
/// Constructor from Params
CUTLASS_HOST_DEVICE
LinearCombinationResidualBlockV2(Params const &params) {
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
skip_elementwise_ = false;
}
/// The "source" tensor corresponds to the residual input
CUTLASS_HOST_DEVICE
bool is_source_needed() const { return true; }
/// Functionally required for serial reduction in the epilogue
/// IMPORTANT: Split-k is supported only when ActivationOp is Identity.
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {
if (k_partition) {
beta_ = ElementCompute(1);
}
if (k_partition != k_partition_count - 1) {
skip_elementwise_ = true;
}
}
/// Applies the operation UnaryOp(BinaryOp(ActivationOp(AB + bias), residual))
CUTLASS_HOST_DEVICE
void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB,
FragmentC const &residual,
FragmentCompute const &bias) const {
UnaryOp unary_op;
BinaryOp1 binary_op;
ActivationOp activation;
FragmentCompute tmp_Accum =
NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
FragmentCompute tmp_residual =
NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual);
FragmentCompute z =
binary_op(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual);
FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z);
NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> convert_z;
frag_Z = convert_z(result_Z);
}
/// Applies the operation UnaryOp(BinaryOp(BinaryOp(ActivationOp(AB + bias), residual1), residual2))
CUTLASS_HOST_DEVICE
void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB,
FragmentC const &residual1, FragmentC const &residual2,
FragmentCompute const &bias) const {
UnaryOp unary_op;
BinaryOp1 binary_op1;
BinaryOp2 binary_op2;
ActivationOp activation;
FragmentCompute tmp_Accum =
NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
FragmentCompute tmp_residual1 =
NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual1);
FragmentCompute tmp_residual2 =
NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual2);
FragmentCompute z =
binary_op2(binary_op1(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual1), beta_ * tmp_residual2);
FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z);
NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> convert_z;
frag_Z = convert_z(result_Z);
}
/// Should never be called
CUTLASS_HOST_DEVICE
void operator()(FragmentOutput &, FragmentOutput &, FragmentAccumulator const &,
FragmentCompute const &) const {}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -58,12 +58,16 @@
#include "cutlass/epilogue/warp/fragment_iterator_simt.h"
#include "cutlass/epilogue/warp/tile_iterator_simt.h"
#include "cutlass/epilogue/threadblock/default_thread_map_simt.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator_pitch_liner.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/epilogue_depthwise.h"
#include "cutlass/layout/permute.h"
@ -314,6 +318,100 @@ struct DefaultEpilogueSimtAffineRankN {
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines sensible defaults for epilogues for SimtOps.
template <typename Shape_, // ThreadBlock Shape
typename WarpMmaSimt_, // mma_depthwise_simt
typename OutputOp_,
int ElementsPerAccess_,
typename ThreadOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1>,
typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1> >
struct DefaultDirectConvEpilogueSimt {
using Shape = Shape_;
using WarpMmaSimt = WarpMmaSimt_;
using WarpShape = typename WarpMmaSimt::Shape;
using OutputOp = OutputOp_;
using ThreadOutputShape = ThreadOutputShape_;
using ThreadBlockOutputShape = ThreadBlockOutputShape_;
static int const kElementsPerAccess = ElementsPerAccess_;
using ElementOutput = typename OutputOp::ElementOutput;
using LayoutC = typename WarpMmaSimt::LayoutC;
using ElementAccumulator = typename WarpMmaSimt::ElementC;
/// Number of threads total
using WarpCount = gemm::GemmShape<
Shape::kM / WarpShape::kM,
Shape::kN / WarpShape::kN
>;
static int const kWarpSize = cutlass::gemm::warp::WarpSize<arch::OpClassSimt>::value;
static int const kThreads = WarpCount::kCount * kWarpSize;
//
// Thread map
//
using OutputTileThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<ThreadBlockOutputShape::kC, ThreadBlockOutputShape::kNHW>,
kThreads,
kElementsPerAccess
>;
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorDirectConv<
OutputTileThreadMap,
ElementOutput,
ThreadOutputShape,
ThreadBlockOutputShape
>;
using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt<
typename WarpMmaSimt::Shape,
typename WarpMmaSimt::ThreadMma,
layout::RowMajor,
typename WarpMmaSimt::Policy
>;
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimtDirect2dConv<
typename WarpMmaSimt::Shape,
ThreadOutputShape,
ThreadBlockOutputShape,
typename WarpMmaSimt::ThreadMma,
ElementAccumulator,
layout::RowMajor,
typename WarpMmaSimt::Policy
>;
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorPitchLiner<
OutputTileThreadMap,
ElementAccumulator
>;
/// Hard-coded padding elements added
using Padding = typename WarpTileIterator::Padding;
//
// Define the epilogue
//
using Epilogue = cutlass::epilogue::threadblock::EpilogueDepthwise<
Shape,
ThreadOutputShape,
ThreadBlockOutputShape,
WarpMmaSimt,
OutputTileIterator,
AccumulatorFragmentIterator,
WarpTileIterator,
SharedLoadIterator,
OutputOp,
Padding
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
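The epilogue thread map above stripes the NHW x C output tile across all threads of the threadblock. A worked example under assumed shapes (not from the diff):

// ThreadBlockOutputShape = <1, 8, 8, 64> (so kNHW = 64, kC = 64),
// kThreads = 128, kElementsPerAccess = 4:
//   accesses per row = 64 / 4 = 16
//   rows per pass    = 128 / 16 = 8
//   iterations       = 64 / 8 = 8 passes to cover the full tile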

View File

@ -293,6 +293,141 @@ struct DefaultIteratorsTensorOp<
static int const kFragmentsPerIteration = 1;
};
/// Partial specialization for float_e4m3_t <= float x 16/8 epilogues avoids shared memory bank conflicts.
/// Threadblock::kN = 256 still has bank conflicts.
template <
int ElementsPerAccess,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename ThreadMap
>
struct DefaultIteratorsTensorOp<
cutlass::float_e4m3_t,
float,
ElementsPerAccess,
ThreadblockShape,
WarpShape,
InstructionShape,
ThreadMap> {
using ElementOutput = cutlass::float_e4m3_t;
static_assert((ElementsPerAccess == 16 || ElementsPerAccess == 8),
"ElementsPerAccess needs to be 16 or 8.");
using WarpTileIteratorMixed = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
WarpShape,
InstructionShape,
float,
32,
cutlass::sizeof_bits<ElementOutput>::value,
ElementsPerAccess,
8
>;
using WarpTileIteratorNotMixed = cutlass::epilogue::warp::TileIteratorTensorOp<
WarpShape,
InstructionShape,
float,
layout::RowMajor
>;
using WarpTileIterator = typename platform::conditional<
(ThreadblockShape::kN == 256),
WarpTileIteratorNotMixed,
WarpTileIteratorMixed>::type;
using SharedLoadIteratorMixed = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
ThreadMap,
float,
32,
cutlass::sizeof_bits<ElementOutput>::value,
ElementsPerAccess,
8
>;
using SharedLoadIteratorNotMixed = cutlass::epilogue::threadblock::SharedLoadIterator<
ThreadMap,
float
>;
using SharedLoadIterator = typename platform::conditional<
(ThreadblockShape::kN == 256),
SharedLoadIteratorNotMixed,
SharedLoadIteratorMixed>::type;
static int const kFragmentsPerIteration = 1;
};
/// Partial specialization for float_e5m2_t <= float x 16/8 epilogues; avoids shared memory bank conflicts.
/// Threadblock::kN = 256 still has bank conflicts.
template <
int ElementsPerAccess,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename ThreadMap
>
struct DefaultIteratorsTensorOp<
cutlass::float_e5m2_t,
float,
ElementsPerAccess,
ThreadblockShape,
WarpShape,
InstructionShape,
ThreadMap> {
using ElementOutput = cutlass::float_e5m2_t;
static_assert((ElementsPerAccess == 16 || ElementsPerAccess == 8),
"ElementsPerAccess needs to be 16 or 8.");
using WarpTileIteratorMixed = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
WarpShape,
InstructionShape,
float,
32,
cutlass::sizeof_bits<ElementOutput>::value,
ElementsPerAccess,
8
>;
using WarpTileIteratorNotMixed = cutlass::epilogue::warp::TileIteratorTensorOp<
WarpShape,
InstructionShape,
float,
layout::RowMajor
>;
using WarpTileIterator = typename platform::conditional<
(ThreadblockShape::kN == 256),
WarpTileIteratorNotMixed,
WarpTileIteratorMixed>::type;
using SharedLoadIteratorMixed = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
ThreadMap,
float,
32,
cutlass::sizeof_bits<ElementOutput>::value,
ElementsPerAccess,
8
>;
using SharedLoadIteratorNotMixed = cutlass::epilogue::threadblock::SharedLoadIterator<
ThreadMap,
float
>;
using SharedLoadIterator = typename platform::conditional<
(ThreadblockShape::kN == 256),
SharedLoadIteratorNotMixed,
SharedLoadIteratorMixed>::type;
static int const kFragmentsPerIteration = 1;
};
} // namespace detail
////////////////////////////////////////////////////////////////////////////////
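The ThreadblockShape::kN == 256 switch above is ordinary compile-time type selection; platform::conditional behaves like std::conditional. A minimal sketch of the pattern, using simplified stand-in iterator types:

#include <cstdio>
#include <type_traits>

struct TileIteratorMixed    { static constexpr char const *name = "mixed"; };
struct TileIteratorNotMixed { static constexpr char const *name = "not mixed"; };

// Mirrors: platform::conditional<(ThreadblockShape::kN == 256), NotMixed, Mixed>::type
template <int kN>
struct SelectWarpTileIterator {
  using type = typename std::conditional<(kN == 256),
                                         TileIteratorNotMixed,
                                         TileIteratorMixed>::type;
};

int main() {
  std::printf("kN = 128 -> %s\n", SelectWarpTileIterator<128>::type::name);
  std::printf("kN = 256 -> %s\n", SelectWarpTileIterator<256>::type::name);
  return 0;
}

Per the comments above, the mixed iterators exist to avoid shared memory bank conflicts; the plain variants are kept for the kN = 256 case, which has bank conflicts either way.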

View File

@ -1,177 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/epilogue_with_broadcast_v2.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Defines sensible defaults for epilogues for TensorOps.
template <
typename Shape,
typename WarpMmaTensorOp,
int PartitionsK,
typename ElementOutput,
typename ElementTensor,
typename ElementVector,
typename OutputOp,
int ElementsPerAccess,
bool ScatterD = false
>
struct DefaultEpilogueWithBroadcastTensorOpV2 {
/// Use defaults related to the existing epilogue
using Base = DefaultEpilogueTensorOp<
Shape,
WarpMmaTensorOp,
PartitionsK,
OutputOp,
ElementsPerAccess
>;
//
// Stores the result z = (y = GEMM(A, B, C), broadcast)
//
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorV2<
typename Base::OutputTileThreadMap,
ElementOutput,
ScatterD
>;
//
// Additional tensor tile iterator - stores t = Elementwise(z)
//
using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorV2<
typename Base::OutputTileThreadMap,
ElementTensor
>;
/// Define the epilogue
using Epilogue = EpilogueWithBroadcastV2<
Shape,
WarpMmaTensorOp,
PartitionsK,
OutputTileIterator,
TensorTileIterator,
ElementVector,
typename Base::AccumulatorFragmentIterator,
typename Base::WarpTileIterator,
typename Base::SharedLoadIterator,
OutputOp,
typename Base::Padding,
Base::kFragmentsPerIteration
>;
};
////////////////////////////////////////////////////////////////////////////////
/// Defines sensible defaults for epilogues for VoltaTensorOps.
template <
typename Shape,
typename WarpMmaTensorOp,
int PartitionsK,
typename ElementOutput,
typename ElementTensor,
typename ElementVector,
typename OutputOp,
int ElementsPerAccess
>
struct DefaultEpilogueWithBroadcastVoltaTensorOpV2 {
/// Use defaults related to the existing epilogue
using Base = DefaultEpilogueVoltaTensorOp<
Shape,
WarpMmaTensorOp,
PartitionsK,
OutputOp,
ElementsPerAccess
>;
//
// Stores the result z = (y = GEMM(A, B, C), broadcast)
//
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorV2<
typename Base::OutputTileThreadMap,
ElementOutput
>;
//
// Additional tensor tile iterator - stores t = Elementwise(z)
//
using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorV2<
typename Base::OutputTileThreadMap,
ElementTensor
>;
/// Define the epilogue
using Epilogue = EpilogueWithBroadcastV2<
Shape,
WarpMmaTensorOp,
PartitionsK,
OutputTileIterator,
TensorTileIterator,
ElementVector,
typename Base::AccumulatorFragmentIterator,
typename Base::WarpTileIterator,
typename Base::SharedLoadIterator,
OutputOp,
typename Base::Padding
>;
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@ -34,6 +34,7 @@
The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.
The shared memory resource is time-sliced across warps.
*/
#pragma once
@ -59,8 +60,9 @@
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/epilogue_base_streamk.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/numeric_types.h"
#include "cutlass/util/index_sequence.h"
////////////////////////////////////////////////////////////////////////////////
@ -68,6 +70,7 @@ namespace cutlass {
namespace epilogue {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Epilogue operator
@ -85,27 +88,39 @@ template <
int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large
(!IsEpilogueFunctorHeavy<OutputOp_>::value)
>
class Epilogue :
class Epilogue :
public EpilogueBase<
Shape_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
Shape_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
Padding_,
FragmentsPerPartition> {
FragmentsPerPartition>,
public EpilogueBaseStreamK<
Shape_,
PartitionsK,
WarpMmaOperator_,
AccumulatorFragmentIterator_>
{
public:
using Base = EpilogueBase<
Shape_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
Shape_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
Padding_,
FragmentsPerPartition>;
using BaseStreamK = EpilogueBaseStreamK<
Shape_,
PartitionsK,
WarpMmaOperator_,
AccumulatorFragmentIterator_>;
using Shape = Shape_;
using WarpMmaOperator = WarpMmaOperator_;
static int const kPartitionsK = PartitionsK;
@ -115,15 +130,23 @@ public:
using SharedLoadIterator = SharedLoadIterator_;
using OutputOp = OutputOp_;
using Padding = Padding_;
using Layout = layout::RowMajor;
using LongIndex = typename Layout::LongIndex;
/// The complete warp-level accumulator tile
/// Number of warps per block
using WarpCount = typename Base::WarpCount;
/// Number of threads per block
static int const kBlockThreads = 32 * WarpCount::kCount;
/// Per-thread accumulator tile type
using AccumulatorTile = typename Base::AccumulatorTile;
/// Accumulator element
using ElementAccumulator = typename WarpTileIterator::Element;
/// Numerical accumulation element type
using ElementAccumulator = typename WarpMmaOperator::ElementC;
/// Fragment type used by the accumulator tile's fragment iterator
using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment;
/// Output element
using ElementOutput = typename OutputTileIterator::Element;
@ -140,21 +163,20 @@ public:
/// Const tensor reference to source tensor
using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
/// Array type used to output
/// Vector type used by the global output iterator
using OutputAccessType = Array<
typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
/// Array type used by output functor
using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
/// Number of warps
using WarpCount = typename Base::WarpCount;
/// Vector type used by the shared output iterator
using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;
public:
static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
"Mismatch between shared load iterator and output tile iterator.");
@ -163,144 +185,177 @@ public:
static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
"Divisibility");
static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1.");
private:
/// Loads fragment from shared memory aligned with output tensor
SharedLoadIterator shared_load_iterator_;
/// Thread index in the threadblock
int thread_idx;
/// Warp index in the threadblock
int warp_idx;
public:
/// Constructor
CUTLASS_DEVICE
Epilogue(
typename Base::SharedStorage &shared_storage, ///< Shared storage object
int thread_idx, ///< ID of a thread within the threadblock
int warp_idx, ///< ID of warp within threadblock
int lane_idx ///< Id of thread within warp
):
Base(shared_storage, thread_idx, warp_idx, lane_idx),
shared_load_iterator_(shared_storage.reference(), thread_idx)
typename Base::SharedStorage &shared_storage, ///< Shared storage object
int thread_idx, ///< ID of a thread within the threadblock
int warp_idx, ///< ID of warp within threadblock
int lane_idx) ///< Id of thread within warp
:
Base(shared_storage, thread_idx, warp_idx, lane_idx),
BaseStreamK(thread_idx),
shared_load_iterator_(shared_storage.reference(), thread_idx),
thread_idx(thread_idx),
warp_idx(warp_idx)
{}
/// Aggregates the accumulator sets shared by peer blocks in the global workspace,
/// performing epilogue computations, writing to output
CUTLASS_DEVICE
void reduce(
int peer_idx_begin,
int peer_idx_end,
int reduce_fragment_idx,
ElementAccumulator *element_workspace,
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
OutputTileIterator source_iterator) ///< Tile iterator for addend source
{
// Reduce peer accumulator fragments into one fragment
AccumulatorFragment accum_fragment;
BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace);
// Store fragment to shared memory
this->warp_tile_iterator_.store(accum_fragment);
__syncthreads();
// Initialize/load source-fragment data
typename OutputTileIterator::Fragment source_fragment;
source_fragment.clear();
if (output_op.is_source_needed())
{
source_iterator += reduce_fragment_idx;
source_iterator.load(source_fragment);
}
// Load fragment from shared memory
typename SharedLoadIterator::Fragment aligned_accum_fragment;
shared_load_iterator_.load(aligned_accum_fragment);
// Add fragments shared by other k partitions
if (kPartitionsK > 1)
{
plus <typename SharedLoadIterator::Fragment> add_fragments;
CUTLASS_PRAGMA_UNROLL
for ( int i = 1; i < kPartitionsK; ++i) {
typename SharedLoadIterator::Fragment aligned_addend_fragment;
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
shared_load_iterator_.load(aligned_addend_fragment);
aligned_accum_fragment = add_fragments(aligned_accum_fragment, aligned_addend_fragment);
}
}
// Compute the output result
typename OutputTileIterator::Fragment output_fragment;
// Apply the output operator
apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment);
// Store the final result
destination_iterator += reduce_fragment_idx;
destination_iterator.store(output_fragment);
}
/// Streams the result to global memory
CUTLASS_DEVICE
void operator()(
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
if (!output_op.is_source_needed()) {
compute_source_not_needed_(output_op, destination_iterator, accumulators);
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator ) ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
{
if (!output_op.is_source_needed())
{
source_iterator.clear_mask();
__syncthreads(); // Dummy (CUDA 11.0)
}
else {
compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator);
}
}
private:
// Source-fragment data (zero-initialized for scenarios where the
// output operator allows us to skip loading it from global input)
typename OutputTileIterator::Fragment source_fragment;
source_fragment.clear();
template <class Seq>
struct acc2smem_source_not_needed;
// Iterator over warp-level accumulator fragment
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
template <size_t... Seq>
struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
template <int Advance>
CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
WarpTileIterator &warp_tile_iterator) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Advance; i++) {
++accum_fragment_iterator;
}
//
// Iterate over accumulator tile
//
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1)
for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration)
{
//
// Convert and store fragment
//
__syncthreads();
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
for (int p = 0; p < Base::kFragmentsPerIteration; ++p)
{
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
++accum_fragment_iterator;
warp_tile_iterator.store(accum_fragment);
this->warp_tile_iterator_.store(accum_fragment);
if (p < Base::kFragmentsPerIteration - 1) {
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset);
}
}
if (Base::kFragmentsPerIteration > 1) {
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset *
(1 - Base::kFragmentsPerIteration));
this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
}
}
CUTLASS_DEVICE
static void push(size_t pos,
AccumulatorFragmentIterator const &iterator_begin,
WarpTileIterator &warp_tile_iterator) {
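// Braced-initializer pack expansion: only the element whose
// Seq * kFragmentsPerIteration equals pos short-circuits into helper<>,
// turning the runtime position into a compile-time advance count.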
int dummy[] = {
(pos == (Seq * Base::kFragmentsPerIteration)) &&
(helper<Seq * Base::kFragmentsPerIteration>(iterator_begin, warp_tile_iterator), 0)...};
CUTLASS_UNUSED(dummy[0]);
}
};
static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1.");
/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_not_needed_(
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators ///< Complete warp-level accumulator tile
) {
//
// Iterator over warp-level accumulator fragment
//
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
//
// Iterate over accumulator tile
//
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1)
for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) {
//
// Convert and store fragment
//
__syncthreads();
acc2smem_source_not_needed<
cutlass::make_index_sequence<OutputTileIterator::kIterations /
Base::kFragmentsPerIteration>>::push(iter,
accum_fragment_iterator,
this->warp_tile_iterator_);
__syncthreads();
//
// Load fragments from shared memory
//
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
__syncthreads();
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < Base::kFragmentsPerIteration; ++p)
{
// Load addend source fragment from global memory
source_iterator.load(source_fragment);
++source_iterator;
typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
shared_load_iterator_.load(aligned_accum_fragment[0]);
if (p < Base::kFragmentsPerIteration - 1) {
if (p < Base::kFragmentsPerIteration - 1)
{
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
}
else if (kPartitionsK > 1) {
else if (kPartitionsK > 1)
{
plus <typename SharedLoadIterator::Fragment> add_fragments;
CUTLASS_PRAGMA_UNROLL
@ -318,9 +373,7 @@ private:
//
typename OutputTileIterator::Fragment output_fragment;
apply_output_operator_source_not_needed_(output_fragment, output_op, aligned_accum_fragment[0]);
apply_output_operator(output_fragment, output_op, aligned_accum_fragment[0], source_fragment);
//
// Store the final result
@ -336,170 +389,37 @@ private:
}
}
template<class Seq>
struct acc2smem_source_needed;
template <size_t... Seq>
struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
template<int Advance>
CUTLASS_DEVICE
static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
WarpTileIterator &warp_tile_iterator) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Advance; i++) {
++accum_fragment_iterator;
}
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
warp_tile_iterator.store(accum_fragment);
}
CUTLASS_DEVICE
static void push(size_t pos,
AccumulatorFragmentIterator const &iterator_begin,
WarpTileIterator &warp_tile_iterator) {
int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
}
};
/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_needed_(
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator ///< Tile iterator for addend source
) {
typename OutputTileIterator::Fragment source_fragment;
source_fragment.clear();
//
// Iterator over warp-level accumulator fragment
//
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
//
// Iterate over accumulator tile
//
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
//
// Load the source
//
source_iterator.load(source_fragment);
++source_iterator;
//
// Convert and store fragment
//
__syncthreads();
acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
iter, accum_fragment_iterator, this->warp_tile_iterator_);
__syncthreads();
//
// Load fragments from shared memory
//
typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
shared_load_iterator_.load(aligned_accum_fragment[0]);
// If the number of k-slices is > 1 - perform a reduction amongst the k-slices
if (kPartitionsK > 1) {
plus <typename SharedLoadIterator::Fragment> add_fragments;
CUTLASS_PRAGMA_UNROLL
for ( int i = 1; i < kPartitionsK; ++i) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
shared_load_iterator_.load(aligned_accum_fragment[i]);
aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
}
shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
}
//
// Compute the output result
//
typename OutputTileIterator::Fragment output_fragment;
apply_output_operator_(output_fragment, output_op, aligned_accum_fragment[0], source_fragment);
//
// Store the final result
//
destination_iterator.store(output_fragment);
++destination_iterator;
}
}
private:
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_(
void apply_output_operator(
typename OutputTileIterator::Fragment &output_fragment,
OutputOp const &output_op, ///< Output operator
typename SharedLoadIterator::Fragment const &aligned_accum_fragment,
typename OutputTileIterator::Fragment const &source_fragment) {
OutputAccessType *output_frag_ptr =
typename OutputTileIterator::Fragment const &source_fragment)
{
OutputAccessType *output_frag_ptr =
reinterpret_cast<OutputAccessType *>(&output_fragment);
AccumulatorAccessType const *compute_frag_ptr =
AccumulatorAccessType const *compute_frag_ptr =
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
OutputAccessType const *source_frag_ptr =
OutputAccessType const *source_frag_ptr =
reinterpret_cast<OutputAccessType const *>(&source_fragment);
int const kOutputOpIterations =
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
for (int i = 0; i < kOutputOpIterations; ++i)
{
// Call the output operator
output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
}
}
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_source_not_needed_(
typename OutputTileIterator::Fragment &output_fragment,
OutputOp const &output_op, ///< Output operator
typename SharedLoadIterator::Fragment const &aligned_accum_fragment) {
OutputAccessType *output_frag_ptr =
reinterpret_cast<OutputAccessType *>(&output_fragment);
AccumulatorAccessType const *compute_frag_ptr =
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
// Call the output operator
output_frag_ptr[i] = output_op(compute_frag_ptr[i]);
}
}
};
////////////////////////////////////////////////////////////////////////////////
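The acc2smem helpers above use a pack-expansion idiom to turn a runtime loop index into a compile-time iterator advance. A self-contained sketch of that dispatch, with illustrative names (helper and push are simplified stand-ins for the members above):

#include <cstdio>
#include <utility>

// Stand-in for advancing the accumulator fragment iterator by a
// compile-time count and storing one fragment.
template <int Advance>
void helper() {
  std::printf("advance by %d, then store fragment\n", Advance);
}

template <std::size_t... Seq>
void push(std::size_t pos, std::index_sequence<Seq...>) {
  // One array element per Seq; the short-circuiting && runs helper<Seq>()
  // only for the element whose Seq equals the runtime position.
  int dummy[] = {(pos == Seq && (helper<Seq>(), 0))...};
  (void)dummy;
}

int main() {
  std::size_t const kIterations = 4;
  for (std::size_t iter = 0; iter < kIterations; ++iter) {
    push(iter, std::make_index_sequence<kIterations>{});
  }
  return 0;
}

Because the advance count is a template parameter, the increment loop inside helper can be fully unrolled, which is the point of the idiom.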

View File

@ -0,0 +1,191 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Basic subset of epilogue functionality for supporting StreamK decompositions
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/block_striped.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// StreamK epilogue functionality for cross-block accumulator fragment reduction
template <
typename Shape, ///< Shape of threadblock tile (concept: GemmShape)
int PartitionsK,
typename WarpMmaOperator, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
typename AccumulatorFragmentIterator> ///< Fragment iterator selecting accumulators
class EpilogueBaseStreamK
{
protected:
/// The complete warp-level accumulator tile
using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;
/// Number of warps
using WarpCount = gemm::GemmShape<
Shape::kM / WarpMmaOperator::Shape::kM,
Shape::kN / WarpMmaOperator::Shape::kN, PartitionsK>;
/// Number of threads per block
static int const kBlockThreads = 32 * WarpCount::kCount;
/// Numerical accumulation element type
using ElementAccumulator = typename WarpMmaOperator::ElementC;
/// Fragment type used by the accumulator tile's fragment iterator
using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment;
/// Block-striped transfer utility for sharing AccumulatorFragment
using BlockStripedT = BlockStriped<kBlockThreads, AccumulatorFragment, ElementAccumulator>;
/// Number of elements per fragment
static int const kFragmentElements = sizeof(AccumulatorFragment) / sizeof(ElementAccumulator);
public:
/// Number of fragments per accumulator tile
static int const kAccumulatorFragments = AccumulatorFragmentIterator::Policy::kIterations;
/// Number of workspace accumulation elements shared per output tile
static int const kPeerAccumulators = WarpMmaOperator::Shape::kMN * WarpCount::kCount;
protected:
/// ElementAccumulator stride in the shared workspace between different peer blocks (two: each peer block can share accumulators for up to two tiles)
static const int kPeerStride = kPeerAccumulators * 2;
public:
/// Thread index in the threadblock
int thread_idx;
public:
/// Constructor
CUTLASS_DEVICE
EpilogueBaseStreamK(
int thread_idx) ///< ID of a thread within the threadblock
:
thread_idx(thread_idx)
{}
/// Aggregates the accumulator sets shared by peer blocks in the global workspace
CUTLASS_DEVICE
void reduce(
AccumulatorFragment &accum_fragment, ///< [out] sum of all shared accumulator fragments for these peer partials
int peer_idx_begin,
int peer_idx_end,
int reduce_fragment_idx,
ElementAccumulator *element_workspace)
{
plus<AccumulatorFragment> add_fragments;
int accum_set_offset =
(peer_idx_begin * kPeerStride) +
(reduce_fragment_idx * kBlockThreads * kFragmentElements);
// Load first peer fragment
BlockStripedT::load(accum_fragment, element_workspace + accum_set_offset, this->thread_idx);
accum_set_offset += kPeerStride; // Move to next peer
accum_set_offset += kPeerAccumulators; // Move to non-starting accumulator set for peer
// Reduce additional peer fragments
#pragma unroll 2
while (accum_set_offset < peer_idx_end * kPeerStride)
{
AccumulatorFragment addend_fragment;
BlockStripedT::load(addend_fragment, element_workspace + accum_set_offset, this->thread_idx);
accum_set_offset += kPeerStride;
accum_fragment = add_fragments(accum_fragment, addend_fragment);
}
}
/// Shares the accumulator set with peers in the global workspace
CUTLASS_DEVICE
void share(
int peer_idx,
ElementAccumulator *element_workspace, ///< Output pointer for writing this block's accumulator set to
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
bool started_tile)
{
int accum_set_offset = peer_idx * kPeerStride;
if (!started_tile) {
// Move to non-starting accumulator set
accum_set_offset += kPeerAccumulators;
}
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
CUTLASS_PRAGMA_UNROLL
for (int iter = 0; iter < kAccumulatorFragments; ++iter)
{
// Acquire reordered accumulator fragment
AccumulatorFragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
++accum_fragment_iterator;
// Store accumulator fragment
BlockStripedT::store(element_workspace + accum_set_offset, accum_fragment, this->thread_idx);
accum_set_offset += (kFragmentElements * kBlockThreads);
}
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
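A host-side sketch of the workspace walk that reduce() above performs; the sizes are hypothetical, and a plain array read stands in for the block-striped fragment load (BlockStripedT::load):

#include <cstdio>
#include <vector>

int main() {
  int const kBlockThreads     = 128;   // hypothetical 32 * WarpCount::kCount
  int const kFragmentElements = 4;     // hypothetical fragment width
  int const kPeerAccumulators = 2048;  // hypothetical accumulators per output tile
  int const kPeerStride       = kPeerAccumulators * 2;  // two sets per peer

  int peer_idx_begin = 0, peer_idx_end = 3, reduce_fragment_idx = 1;
  std::vector<float> workspace(peer_idx_end * kPeerStride, 1.0f);

  // First peer: its *starting* accumulator set, at this fragment's offset.
  int accum_set_offset = (peer_idx_begin * kPeerStride) +
                         (reduce_fragment_idx * kBlockThreads * kFragmentElements);
  float accum = workspace[accum_set_offset];

  // Subsequent peers contribute their *non-starting* accumulator sets.
  accum_set_offset += kPeerStride;        // move to next peer
  accum_set_offset += kPeerAccumulators;  // skip to non-starting set
  while (accum_set_offset < peer_idx_end * kPeerStride) {
    accum += workspace[accum_set_offset];
    accum_set_offset += kPeerStride;
  }
  std::printf("reduced accumulator: %f\n", accum);
  return 0;
}

Each peer block owns kPeerStride elements, split into a starting and a non-starting accumulator set, which is why the walk skips ahead by kPeerAccumulators after the first peer.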

View File

@ -0,0 +1,335 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for depthwise convolution
The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/conversion_op.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/reduction_op.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/numeric_types.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Epilogue operator
template <typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
typename ThreadOutputShape_, /// Size of the matrix to load (concept: TensorNHWC)
typename ThreadBlockOutputShape_, /// Size of the matrix to load (concept: TensorNHWC)
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept:
///< gemm::warp::MmaTensorOp)
typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors
typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM
typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM
typename OutputOp_, ///< Output operator
typename Padding_ ///< Padding added to SMEM allocation to avoid bank conflicts (concept:
///< MatrixShape)
>
class EpilogueDepthwise {
public:
using Shape = Shape_;
using WarpShape = typename WarpMmaOperator_::Shape;
using ThreadOutputShape = ThreadOutputShape_;
using ThreadBlockOutputShape = ThreadBlockOutputShape_;
using WarpMmaOperator = WarpMmaOperator_;
using OutputTileIterator = OutputTileIterator_;
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
using WarpTileIterator = WarpTileIterator_;
using SharedLoadIterator = SharedLoadIterator_;
using OutputOp = OutputOp_;
using Padding = Padding_;
using Layout = layout::RowMajor;
using LongIndex = typename Layout::LongIndex;
/// The complete warp-level accumulator tile
using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;
/// Accumulator element
using ElementAccumulator = typename WarpTileIterator::Element;
/// Output element
using ElementOutput = typename OutputTileIterator::Element;
/// Output access size
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
/// Tensor reference to destination tensor
using TensorRef = typename OutputTileIterator::TensorRef;
/// Tensor reference to sync tensor
using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
/// Const tensor reference to source tensor
using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
/// Array type used to output
using OutputAccessType =
Array<typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
/// Array type used by output functor
using AccumulatorAccessType =
Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
/// Number of warps
using WarpCount =
gemm::GemmShape<Shape::kM / WarpShape::kM, Shape::kN / WarpShape::kN>;
public:
static_assert(SharedLoadIterator::Fragment::kElements ==
OutputTileIterator::Fragment::kElements,
"Mismatch between shared load iterator and output tile iterator.");
static_assert(OutputTileIterator::kElementsPerAccess,
"OutputTileIterator::kElementsPerAccess must not be zero.");
static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
"Divisibility");
/// Shared storage allocation needed by the epilogue
struct SharedStorage {
//
// Type definitions
//
/// Element type of shared memory
using Element = typename WarpTileIterator::Element;
/// Tensor reference to shared memory allocation
using TensorRef = typename WarpTileIterator::TensorRef;
/// Layout of shared memory allocation
using Layout = typename WarpTileIterator::Layout;
/// Logical shape of the shared memory tile written to by all warps.
using Shape = MatrixShape<ThreadBlockOutputShape::kNHW, ThreadBlockOutputShape::kC>;
/// Shape of the shared memory allocation for the epilogue
using StorageShape = MatrixShape<Shape::kRow, Shape::kColumn>;
//
// Data members
//
AlignedBuffer<Element, StorageShape::kCount> storage;
//
// Methods
//
/// Returns a pointer to the shared memory buffer
CUTLASS_DEVICE
Element *data() { return storage.data(); }
/// Returns a tensor reference to the shared memory buffer
CUTLASS_DEVICE
TensorRef reference() {
return TensorRef(storage.data(), Layout::packed({StorageShape::kRow, StorageShape::kColumn}));
}
};
private:
/// Loads fragment from shared memory aligned with output tensor
SharedLoadIterator shared_load_iterator_;
/// Stores a warp's fragment of accumulators to SMEM
WarpTileIterator warp_tile_iterator_;
LongIndex warp_offset;
int thread_idx;
int warp_idx;
int lane_idx;
int warp_m, warp_n; // warp coordinates within a cta
int tid_m, tid_n; // thread coordinates within a warp
public:
/// Constructor
CUTLASS_DEVICE
EpilogueDepthwise(SharedStorage &shared_storage, ///< Shared storage object
int thread_idx_, ///< ID of a thread within the threadblock
int warp_idx_, ///< ID of warp within threadblock
int lane_idx_ ///< Id of thread within warp
)
: thread_idx(thread_idx_),
warp_idx(warp_idx_),
lane_idx(lane_idx_),
shared_load_iterator_(shared_storage.reference(), thread_idx_),
warp_tile_iterator_(shared_storage.reference(), thread_idx_, lane_idx_) {}
/// Streams the result to global memory
CUTLASS_DEVICE
void operator()(OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator, ///< Tile iterator for source
const int smem_base_offset) { ///< SMEM base offset for epilogue operation
// Initialize the SMEM base offset for each output tile.
warp_tile_iterator_.set_smem_base_address(smem_base_offset);
shared_load_iterator_.set_smem_base_address(smem_base_offset);
if (!output_op.is_source_needed()) {
compute_source_not_needed_(output_op, destination_iterator, accumulators);
} else {
compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator);
}
}
private:
/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_needed_(
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator) { ///< Tile iterator for source
typename OutputTileIterator::Fragment source_fragment;
source_fragment.clear();
source_iterator.load(source_fragment);
// store to smem
warp_tile_iterator_.store(accumulators);
__syncthreads();
typename SharedLoadIterator::Fragment aligned_accum_fragment;
// load from smem
shared_load_iterator_.load(aligned_accum_fragment);
typename OutputTileIterator::Fragment output_fragment;
apply_output_operator_(output_fragment, output_op, aligned_accum_fragment, source_fragment);
// Store to GMEM
destination_iterator.store(output_fragment);
}
/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_not_needed_(
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators) { ///< Complete warp-level accumulator tile
// store to smem
warp_tile_iterator_.store(accumulators);
__syncthreads();
typename SharedLoadIterator::Fragment aligned_accum_fragment;
// load from smem
shared_load_iterator_.load(aligned_accum_fragment);
typename OutputTileIterator::Fragment output_fragment;
apply_output_operator_source_not_needed_(output_fragment, output_op, aligned_accum_fragment);
// Store to GMEM
destination_iterator.store(output_fragment);
}
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_(
typename OutputTileIterator::Fragment &output_fragment,
OutputOp const &output_op, ///< Output operator
typename SharedLoadIterator::Fragment const &aligned_accum_fragment,
typename OutputTileIterator::Fragment const &source_fragment) {
OutputAccessType *output_frag_ptr =
reinterpret_cast<OutputAccessType *>(&output_fragment);
AccumulatorAccessType const *compute_frag_ptr =
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
OutputAccessType const *source_frag_ptr =
reinterpret_cast<OutputAccessType const *>(&source_fragment);
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
// Call the output operator
output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
}
}
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_source_not_needed_(
typename OutputTileIterator::Fragment &output_fragment,
OutputOp const &output_op, ///< Output operator
typename SharedLoadIterator::Fragment const &aligned_accum_fragment) {
OutputAccessType *output_frag_ptr = reinterpret_cast<OutputAccessType *>(&output_fragment);
AccumulatorAccessType const *compute_frag_ptr =
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
// Call the output operator
output_frag_ptr[i] = output_op(compute_frag_ptr[i]);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
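Stripped of iterators and thread maps, the per-tile flow above is: store accumulators to shared memory, synchronize, reload them in the output-aligned order, apply the output operator, and store to global memory. A minimal CUDA sketch of that sequence, assuming a trivially 1:1 shared-memory layout and alpha-scaling in place of the real output operator:

#include <cstdio>
#include <cuda_runtime.h>

__global__ void epilogue_sketch(float const *accum, float *out, float alpha, int n) {
  extern __shared__ float smem[];
  int idx = threadIdx.x;
  if (idx < n) {
    smem[idx] = accum[idx];   // warp_tile_iterator_.store(accumulators)
  }
  __syncthreads();
  if (idx < n) {
    float frag = smem[idx];   // shared_load_iterator_.load(aligned_accum_fragment)
    out[idx] = alpha * frag;  // apply_output_operator_source_not_needed_(...)
  }
}

int main() {
  int const n = 64;
  float *accum = nullptr, *out = nullptr;
  cudaMallocManaged(&accum, n * sizeof(float));  // error checks omitted for brevity
  cudaMallocManaged(&out, n * sizeof(float));
  for (int i = 0; i < n; ++i) accum[i] = float(i);

  epilogue_sketch<<<1, n, n * sizeof(float)>>>(accum, out, 0.5f, n);
  cudaDeviceSynchronize();

  std::printf("out[10] = %f\n", out[10]);
  cudaFree(accum);
  cudaFree(out);
  return 0;
}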

View File

@ -77,7 +77,6 @@ public:
using OutputTileIterator = OutputTileIterator_;
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
using WarpTileIterator = WarpTileIterator_;
using SharedLoadIterator = SharedLoadIterator_;
using OutputOp = OutputOp_;
using Padding = MatrixShape<0, 0>;

View File

@ -133,7 +133,8 @@ struct EpilogueWithBroadcastOpBase {
FragmentZ &frag_Z,
FragmentT &frag_T,
FragmentAccumulator const &AB,
FragmentC const &frag_C,
FragmentC const &frag_C1,
FragmentC const &frag_C2,
FragmentCompute const &V) const {
}
@ -180,9 +181,42 @@ template <
typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
int FragmentsPerPartition = 1, ///< Used to coarsen the epilogue granularity
int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large
(!IsEpilogueFunctorHeavy<OutputOp_>::value)
(!IsEpilogueFunctorHeavy<OutputOp_>::value),
bool IsSingleSource = OutputOp_::kIsSingleSource
>
class EpilogueWithBroadcast :
class EpilogueWithBroadcast;
template <
typename Shape_,
typename WarpMmaOperator_,
int PartitionsK,
typename OutputTileIterator_,
typename TensorTileIterator_,
typename ElementVector_,
typename AccumulatorFragmentIterator_,
typename WarpTileIterator_,
typename SharedLoadIterator_,
typename OutputOp_,
typename Padding_,
int FragmentsPerPartition,
int IterationsUnroll
>
class EpilogueWithBroadcast<
Shape_,
WarpMmaOperator_,
PartitionsK,
OutputTileIterator_,
TensorTileIterator_,
ElementVector_,
AccumulatorFragmentIterator_,
WarpTileIterator_,
SharedLoadIterator_,
OutputOp_,
Padding_,
FragmentsPerPartition,
IterationsUnroll,
false
> :
public EpilogueBase<
Shape_,
typename WarpMmaOperator_::Shape,
@ -203,6 +237,7 @@ public:
Padding_,
FragmentsPerPartition>;
static bool const kIsSingleSource = false;
using Shape = Shape_;
using WarpMmaOperator = WarpMmaOperator_;
static int const kPartitionsK = PartitionsK;
@ -383,7 +418,687 @@ public:
CUTLASS_DEVICE
void operator()(
OutputOp const &output_op, ///< Output operator
ElementVector const * broadcast_ptr, ///< Broadcast vector
ElementVector const * broadcast_ptr, ///< Broadcast vector
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator1, ///< Tile iterator for first source accumulator matrix
OutputTileIterator source_iterator2, ///< Tile iterator for second source accumulator matrix
TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand
MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses
MatrixCoord(Shape::kM, Shape::kN),
MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space
MatrixCoord()) {
BroadcastFragment broadcast_fragment;
load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset);
if (!output_op.is_source_needed()) {
compute_source_not_needed_(
output_op,
broadcast_fragment,
destination_iterator,
accumulators,
tensor_iterator);
}
else {
compute_source_needed_(
output_op,
broadcast_fragment,
destination_iterator,
accumulators,
source_iterator1,
source_iterator2,
tensor_iterator);
}
}
private:
CUTLASS_DEVICE
void load_broadcast_fragment_(
BroadcastFragment & broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
ElementVector const * broadcast_ptr, ///< Broadcast vector
MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses
MatrixCoord const &threadblock_offset ///< Threadblock's initial offset within the problem size space
) {
broadcast_fragment.clear();
// If no pointer is supplied, set with all zeros and avoid memory accesses
if (!broadcast_ptr) {
return;
}
int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column();
int thread_column_idx = threadblock_offset.column() + thread_initial_column;
broadcast_ptr += thread_initial_column;
NumericArrayConverter<ElementCompute, ElementVector, BroadcastDetail::kElementsPerAccess> converter;
using AccessType = AlignedArray<ElementVector, BroadcastDetail::kElementsPerAccess>;
using ComputeFragmentType = Array<ElementCompute, BroadcastDetail::kElementsPerAccess>;
ComputeFragmentType *frag_ptr = reinterpret_cast<ComputeFragmentType *>(&broadcast_fragment);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) {
AccessType loaded;
loaded.clear();
if (thread_column_idx < problem_size.column()) {
loaded = *reinterpret_cast<AccessType const *>(broadcast_ptr);
}
ComputeFragmentType cvt = converter(loaded);
frag_ptr[j] = cvt;
thread_column_idx += ThreadMap::Delta::kColumn;
broadcast_ptr += ThreadMap::Delta::kColumn;
}
}
template <class Seq>
struct acc2smem_source_not_needed;
template <size_t... Seq>
struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
template <int Advance>
CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
WarpTileIterator &warp_tile_iterator) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Advance; i++) {
++accum_fragment_iterator;
}
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
++accum_fragment_iterator;
warp_tile_iterator.store(accum_fragment);
if (p < Base::kFragmentsPerIteration - 1) {
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
}
}
if (Base::kFragmentsPerIteration > 1) {
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset *
(1 - Base::kFragmentsPerIteration));
}
}
CUTLASS_DEVICE
static void push(size_t pos,
AccumulatorFragmentIterator const &iterator_begin,
WarpTileIterator &warp_tile_iterator) {
int dummy[] = {
(pos == (Seq * Base::kFragmentsPerIteration)) &&
(helper<Seq * Base::kFragmentsPerIteration>(iterator_begin, warp_tile_iterator), 0)...};
CUTLASS_UNUSED(dummy[0]);
}
};
/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_not_needed_(
OutputOp const &output_op, ///< Output operator
BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additional tensor operand
) {
//
// Iterator over warp-level accumulator fragment
//
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
//
// Iterate over accumulator tile
//
// CUTLASS_PRAGMA_UNROLL
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1)
for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) {
//
// Convert and store fragment
//
__syncthreads();
acc2smem_source_not_needed<
cutlass::make_index_sequence<OutputTileIterator::kIterations /
Base::kFragmentsPerIteration>>::push(iter,
accum_fragment_iterator,
this->warp_tile_iterator_);
__syncthreads();
//
// Load fragments from shared memory
//
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
shared_load_iterator_.load(aligned_accum_fragment[0]);
if (p < Base::kFragmentsPerIteration - 1) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
}
else if (kPartitionsK > 1) {
plus <typename SharedLoadIterator::Fragment> add_fragments;
CUTLASS_PRAGMA_UNROLL
for ( int i = 1; i < kPartitionsK; ++i) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
shared_load_iterator_.load(aligned_accum_fragment[i]);
aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
}
shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
}
//
// Apply output operation
//
typename OutputTileIterator::Fragment frag_Z;
typename TensorTileIterator::Fragment frag_T;
apply_output_operator_source_not_needed_(
frag_Z,
frag_T,
output_op,
aligned_accum_fragment[0],
broadcast_fragment);
//
// Conditionally store fragments
//
if (OutputOp::kStoreZ) {
destination_iterator.store(frag_Z);
++destination_iterator;
}
if (OutputOp::kStoreT) {
tensor_iterator.store(frag_T);
++tensor_iterator;
}
}
if (Base::kFragmentsPerIteration > 1) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
}
}
}
template<class Seq>
struct acc2smem_source_needed;
template <size_t... Seq>
struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
template<int Advance>
CUTLASS_DEVICE
static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
WarpTileIterator &warp_tile_iterator) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Advance; i++) {
++accum_fragment_iterator;
}
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
warp_tile_iterator.store(accum_fragment);
}
CUTLASS_DEVICE
static void push(size_t pos,
AccumulatorFragmentIterator const &iterator_begin,
WarpTileIterator &warp_tile_iterator) {
int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
}
};
/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_needed_(
OutputOp const &output_op, ///< Output operator
BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator1, ///< Tile iterator for first source accumulator matrix
OutputTileIterator source_iterator2, ///< Tile iterator for second source accumulator matrix
TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additional tensor operand
) {
typename OutputTileIterator::Fragment source_fragment1;
source_fragment1.clear();
typename OutputTileIterator::Fragment source_fragment2;
source_fragment2.clear();
//
// Iterator over warp-level accumulator fragment
//
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
//
// Iterate over accumulator tile
//
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
//
// Load the source
//
source_iterator1.load(source_fragment1);
++source_iterator1;
source_iterator2.load(source_fragment2);
++source_iterator2;
//
// Convert and store fragment
//
__syncthreads();
acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
iter, accum_fragment_iterator, this->warp_tile_iterator_);
__syncthreads();
//
// Load fragments from shared memory
//
typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
shared_load_iterator_.load(aligned_accum_fragment[0]);
// If the number of k-slices is > 1 - perform a reduction amongst the k-slices
if (kPartitionsK > 1)
{
plus <typename SharedLoadIterator::Fragment> add_fragments;
const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK;
CUTLASS_PRAGMA_UNROLL
for ( int i = 1; i < kPartitionsK; ++i) {
shared_load_iterator_.add_tile_offset({tile_row_offset , 0});
shared_load_iterator_.load(aligned_accum_fragment[i]);
aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
}
shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0});
}
//
// Apply output operation
//
typename OutputTileIterator::Fragment frag_Z;
typename TensorTileIterator::Fragment frag_T;
apply_output_operator_(
frag_Z,
frag_T,
output_op,
aligned_accum_fragment[0],
source_fragment1,
source_fragment2,
broadcast_fragment);
//
// Conditionally store fragments
//
if (OutputOp::kStoreZ) {
destination_iterator.store(frag_Z);
++destination_iterator;
}
if (OutputOp::kStoreT) {
tensor_iterator.store(frag_T);
++tensor_iterator;
}
}
}
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_(
typename OutputTileIterator::Fragment &frag_Z,
typename TensorTileIterator::Fragment &frag_T,
OutputOp const &output_op,
typename SharedLoadIterator::Fragment const &frag_AB,
typename OutputTileIterator::Fragment const &frag_C1,
typename OutputTileIterator::Fragment const &frag_C2,
BroadcastFragment const &frag_Broadcast) {
using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;
AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);
AccumulatorAccessType const *frag_AB_ptr =
reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);
OutputAccessType const *frag_C1_ptr =
reinterpret_cast<OutputAccessType const *>(&frag_C1);
OutputAccessType const *frag_C2_ptr =
reinterpret_cast<OutputAccessType const *>(&frag_C2);
AccessTypeBroadcast const *frag_Broadcast_ptr =
reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
output_op(
frag_Z_ptr[i],
frag_T_ptr[i],
frag_AB_ptr[i],
frag_C1_ptr[i],
frag_C2_ptr[i],
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
}
}
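// Note: frag_Broadcast holds only ThreadMap::Iterations::kColumn vectors -
// one per column iteration - so the modulo above reuses the same broadcast
// values for every row covered by the fragment.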
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_source_not_needed_(
typename OutputTileIterator::Fragment &frag_Z,
typename TensorTileIterator::Fragment &frag_T,
OutputOp const &output_op,
typename SharedLoadIterator::Fragment const &frag_AB,
BroadcastFragment const &frag_Broadcast) {
using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;
AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);
AccumulatorAccessType const *frag_AB_ptr =
reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);
AccessTypeBroadcast const *frag_Broadcast_ptr =
reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
output_op(
frag_Z_ptr[i],
frag_T_ptr[i],
frag_AB_ptr[i],
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
}
}
};
template <
typename Shape_,
typename WarpMmaOperator_,
int PartitionsK,
typename OutputTileIterator_,
typename TensorTileIterator_,
typename ElementVector_,
typename AccumulatorFragmentIterator_,
typename WarpTileIterator_,
typename SharedLoadIterator_,
typename OutputOp_,
typename Padding_,
int FragmentsPerPartition,
int IterationsUnroll
>
class EpilogueWithBroadcast<
Shape_,
WarpMmaOperator_,
PartitionsK,
OutputTileIterator_,
TensorTileIterator_,
ElementVector_,
AccumulatorFragmentIterator_,
WarpTileIterator_,
SharedLoadIterator_,
OutputOp_,
Padding_,
FragmentsPerPartition,
IterationsUnroll,
true
> :
public EpilogueBase<
Shape_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
Padding_,
FragmentsPerPartition> {
public:
using Base = EpilogueBase<
Shape_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
Padding_,
FragmentsPerPartition>;
static bool const kIsSingleSource = true;
using Shape = Shape_;
using WarpMmaOperator = WarpMmaOperator_;
static int const kPartitionsK = PartitionsK;
using OutputTileIterator = OutputTileIterator_;
using TensorTileIterator = TensorTileIterator_;
using ElementVector = ElementVector_;
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
using WarpTileIterator = WarpTileIterator_;
using SharedLoadIterator = SharedLoadIterator_;
using OutputOp = OutputOp_;
using Padding = Padding_;
using Layout = layout::RowMajor;
using LongIndex = typename Layout::LongIndex;
/// The complete warp-level accumulator tile
using AccumulatorTile = typename Base::AccumulatorTile;
/// Accumulator element
using ElementAccumulator = typename WarpTileIterator::Element;
/// Compute data type produced by the output op
using ElementCompute = typename OutputOp::ElementCompute;
/// Compute fragment
using FragmentCompute = Array<ElementCompute, OutputTileIterator::Fragment::kElements>;
/// Thread map used by output tile iterators
using ThreadMap = typename OutputTileIterator::ThreadMap;
/// Fragment object used to store the broadcast values
using BroadcastFragment = Array<
ElementCompute,
ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>;
/// Output element
using ElementOutput = typename OutputTileIterator::Element;
/// Data type of additional tensor
using ElementTensor = typename TensorTileIterator::Element;
/// Output access size
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
/// Tensor reference to destination tensor
using TensorRef = typename OutputTileIterator::TensorRef;
/// Tensor reference to sync tensor
using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
/// Const tensor reference to source tensor
using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
/// Array type used to output
using OutputAccessType = Array<
typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
/// Array type used by output functor
using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
/// Array type used by output functor
using ComputeAccessType = Array<ElementCompute, OutputTileIterator::kElementsPerAccess>;
/// Tensor access type
using TensorAccessType = Array<ElementTensor, OutputTileIterator::kElementsPerAccess>;
/// Number of warps
using WarpCount = typename Base::WarpCount;
/// Shared memory allocation from epilogue base class
using BaseSharedStorage = typename Base::SharedStorage;
static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;
/// Used for the broadcast
struct BroadcastDetail {
/// Number of threads per warp
static int const kWarpSize = 32;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
/// Number of distinct scalar column indices handled by each thread
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
/// Number of distinct scalar row indices handled by each thread
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
/// Number of threads per threadblock
static int const kThreadCount = kWarpSize * WarpCount::kCount;
/// Number of distinct threads per row of output tile
static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread);
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
static int const kThreadRows = kThreadCount / kThreadsPerRow;
/// Number of accesses each thread makes per row when covering Shape::kN columns (at least one)
static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
/// Shape of the shared memory allocation for the epilogue
using StorageShape = MatrixShape<
kThreadRows,
Shape::kN
>;
/// Debug printing
CUTLASS_DEVICE
static void print() {
#if 0
printf("BroadcastDetail {\n");
printf(
" kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n"
"kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n",
kColumnsPerThread,
kRowsPerThread,
kThreadCount,
kThreadsPerRow,
kThreadRows,
kThreadAccessesPerRow,
StorageShape::kRow,
StorageShape::kColumn,
StorageShape::kCount
);
printf("};\n");
#endif
}
};
/// Shared storage structure (shadows base) with additional SMEM buffer for reduction
struct SharedStorage {
union {
BaseSharedStorage base;
};
CUTLASS_HOST_DEVICE
SharedStorage() { }
};
public:
static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
"Mismatch between shared load iterator and output tile iterator.");
static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");
static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
"Divisibility");
private:
/// Loads fragment from shared memory aligned with output tensor
SharedLoadIterator shared_load_iterator_;
/// Thread index within the threadblock
int thread_idx_;
public:
/// Constructor
CUTLASS_DEVICE
EpilogueWithBroadcast(
SharedStorage &shared_storage, ///< Shared storage object
int thread_idx, ///< ID of a thread within the threadblock
int warp_idx, ///< ID of warp within threadblock
int lane_idx ///< Id of thread within warp
):
Base(shared_storage.base, thread_idx, warp_idx, lane_idx),
shared_load_iterator_(shared_storage.base.reference(), thread_idx),
thread_idx_(thread_idx)
{
}
/// Streams the result to global memory
CUTLASS_DEVICE
void operator()(
OutputOp const &output_op, ///< Output operator
ElementVector const * broadcast_ptr, ///< Broadcast vector
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix
@ -646,7 +1361,7 @@ private:
BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix
TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additional tensor operand
) {
@ -759,7 +1474,7 @@ private:
AccumulatorAccessType const *frag_AB_ptr =
reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);
OutputAccessType const *frag_C_ptr =
reinterpret_cast<OutputAccessType const *>(&frag_C);
AccessTypeBroadcast const *frag_Broadcast_ptr =
@ -770,13 +1485,12 @@ private:
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
output_op(
frag_Z_ptr[i],
frag_T_ptr[i],
frag_AB_ptr[i],
frag_C_ptr[i],
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
}
}

View File

@ -1,847 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/
#pragma once
#include <utility>
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/functional.h"
#include "cutlass/fast_math.h"
#include "cutlass/layout/vector.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_v2.h"
#include "cutlass/util/index_sequence.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// This base class is meant to define the concept required of the
/// EpilogueWithBroadcast::OutputOp
template <
typename ElementC_,
typename ElementAccumulator_,
typename ElementCompute_,
typename ElementZ_,
typename ElementT_,
int ElementsPerAccess,
bool StoreZ = true,
bool StoreT = true
>
struct EpilogueWithBroadcastOpBaseV2 {
using ElementOutput = ElementC_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
using ElementZ = ElementZ_;
using ElementT = ElementT_;
static int const kElementsPerAccess = ElementsPerAccess;
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
using FragmentC = Array<ElementOutput, kElementsPerAccess>;
using FragmentZ = Array<ElementZ, kElementsPerAccess>;
using FragmentT = Array<ElementT, kElementsPerAccess>;
/// If true, the 'Z' tensor is stored
static bool const kStoreZ = StoreZ;
/// If true, the 'T' tensor is stored
static bool const kStoreT = StoreT;
/// Parameters structure - required
struct Params { };
//
// Methods
//
/// Constructor from Params
EpilogueWithBroadcastOpBaseV2(Params const &params_) { }
/// Determine if the source is needed. May return false if the source tensor is never read by the output operator.
bool is_source_needed() const {
return true;
}
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) { }
/// Applies the operation when is_source_needed() is true
CUTLASS_HOST_DEVICE
void operator()(
FragmentZ &frag_Z,
FragmentT &frag_T,
FragmentAccumulator const &AB,
FragmentC const &frag_C1,
FragmentC const &frag_C2,
FragmentCompute const &V) const {
}
/// Applies the operation when is_source_needed() is false
CUTLASS_HOST_DEVICE
void operator()(
FragmentZ &frag_Z,
FragmentT &frag_T,
FragmentAccumulator const &AB,
FragmentCompute const &V) const {
}
};
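// A minimal sketch of a functor satisfying this concept (hypothetical name
// and element types, illustration only): it computes Z = AB + C1 + C2 + V
// and passes the accumulators through unconverted as T.
//
//   using BaseOp = EpilogueWithBroadcastOpBaseV2<float, float, float, float, float, 4>;
//
//   struct AddBiasOpSketch : BaseOp {
//     AddBiasOpSketch(Params const &params): BaseOp(params) { }
//
//     CUTLASS_HOST_DEVICE
//     void operator()(FragmentZ &Z, FragmentT &T,
//                     FragmentAccumulator const &AB,
//                     FragmentC const &C1, FragmentC const &C2,
//                     FragmentCompute const &V) const {
//       plus<FragmentZ> add;
//       Z = add(add(AB, C1), add(C2, V));
//       T = AB;
//     }
//   };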
////////////////////////////////////////////////////////////////////////////////
/// Epilogue operator with bias vector broadcast over columns.
///
/// Computes the following:
///
///
/// Z, T = OutputOp(AB, C, Broadcast)
///
/// if (OutputOp::kStoreZ) {
/// store(Z);
/// }
///
/// if (OutputOp::kStoreT) {
/// store(T);
/// }
///
template <
typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
int PartitionsK, ///< Number of partitions of the K dimension
typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors (z)
typename TensorTileIterator_, ///< Additional tile iterator for tensor-valued operands (t)
typename ElementVector_, ///< Pointer to broadcast vector
typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM
typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM
typename OutputOp_, ///< Output operator - concept is EpilogueWithBroadcastOp
typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
int FragmentsPerPartition = 1, ///< Used to coarsen the epilogue granularity
int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large
(!IsEpilogueFunctorHeavy<OutputOp_>::value)
>
class EpilogueWithBroadcastV2 :
public EpilogueBase<
Shape_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
Padding_,
FragmentsPerPartition> {
public:
using Base = EpilogueBase<
Shape_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
Padding_,
FragmentsPerPartition>;
using Shape = Shape_;
using WarpMmaOperator = WarpMmaOperator_;
static int const kPartitionsK = PartitionsK;
using OutputTileIterator = OutputTileIterator_;
using TensorTileIterator = TensorTileIterator_;
using ElementVector = ElementVector_;
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
using WarpTileIterator = WarpTileIterator_;
using SharedLoadIterator = SharedLoadIterator_;
using OutputOp = OutputOp_;
using Padding = Padding_;
using Layout = layout::RowMajor;
using LongIndex = typename Layout::LongIndex;
/// The complete warp-level accumulator tile
using AccumulatorTile = typename Base::AccumulatorTile;
/// Accumulator element
using ElementAccumulator = typename WarpTileIterator::Element;
/// Compute data type produced by the output op
using ElementCompute = typename OutputOp::ElementCompute;
/// Compute fragment
using FragmentCompute = Array<ElementCompute, OutputTileIterator::Fragment::kElements>;
/// Thread map used by output tile iterators
using ThreadMap = typename OutputTileIterator::ThreadMap;
/// Fragment object used to store the broadcast values
using BroadcastFragment = Array<
ElementCompute,
ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>;
/// Output element
using ElementOutput = typename OutputTileIterator::Element;
/// Data type of additional tensor
using ElementTensor = typename TensorTileIterator::Element;
/// Output access size
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
/// Tensor reference to destination tensor
using TensorRef = typename OutputTileIterator::TensorRef;
/// Tensor reference to sync tensor
using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
/// Const tensor reference to source tensor
using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
/// Array type used to output
using OutputAccessType = Array<
typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
/// Array type used by output functor
using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
/// Array type used by output functor
using ComputeAccessType = Array<ElementCompute, OutputTileIterator::kElementsPerAccess>;
/// Tensor access type
using TensorAccessType = Array<ElementTensor, OutputTileIterator::kElementsPerAccess>;
/// Number of warps
using WarpCount = typename Base::WarpCount;
/// Shared memory allocation from epilogue base class
using BaseSharedStorage = typename Base::SharedStorage;
static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;
/// Used for the broadcast
struct BroadcastDetail {
/// Number of threads per warp
static int const kWarpSize = 32;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
/// Number of distinct scalar column indices handled by each thread
static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;
/// Number of distinct scalar row indices handled by each thread
static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;
/// Number of threads per threadblock
static int const kThreadCount = kWarpSize * WarpCount::kCount;
/// Number of distinct threads per row of output tile
static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread);
/// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
static int const kThreadRows = kThreadCount / kThreadsPerRow;
/// Number of accesses each thread makes per row when covering Shape::kN columns (at least one)
static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);
/// Shape of the shared memory allocation for the epilogue
using StorageShape = MatrixShape<
kThreadRows,
Shape::kN
>;
/// Debug printing
CUTLASS_DEVICE
static void print() {
#if 0
printf("BroadcastDetail {\n");
printf(
" kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n"
"kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n",
kColumnsPerThread,
kRowsPerThread,
kThreadCount,
kThreadsPerRow,
kThreadRows,
kThreadAccessesPerRow,
StorageShape::kRow,
StorageShape::kColumn,
StorageShape::kCount
);
printf("};\n");
#endif
}
};
/// Shared storage structure (shadows base) with additional SMEM buffer for reduction
struct SharedStorage {
union {
BaseSharedStorage base;
};
CUTLASS_HOST_DEVICE
SharedStorage() { }
};
public:
static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
"Mismatch between shared load iterator and output tile iterator.");
static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");
static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
"Divisibility");
private:
/// Loads fragment from shared memory aligned with output tensor
SharedLoadIterator shared_load_iterator_;
/// Thread index within the threadblock
int thread_idx_;
public:
/// Constructor
CUTLASS_DEVICE
EpilogueWithBroadcastV2(
SharedStorage &shared_storage, ///< Shared storage object
int thread_idx, ///< ID of a thread within the threadblock
int warp_idx, ///< ID of warp within threadblock
int lane_idx ///< Id of thread within warp
):
Base(shared_storage.base, thread_idx, warp_idx, lane_idx),
shared_load_iterator_(shared_storage.base.reference(), thread_idx),
thread_idx_(thread_idx)
{
}
/// Streams the result to global memory
CUTLASS_DEVICE
void operator()(
OutputOp const &output_op, ///< Output operator
ElementVector const * broadcast_ptr, ///< Broadcast vector
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator1, ///< Tile iterator for first source accumulator matrix
OutputTileIterator source_iterator2, ///< Tile iterator for second source accumulator matrix
TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand
MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses
MatrixCoord(Shape::kM, Shape::kN),
MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space
MatrixCoord()) {
BroadcastFragment broadcast_fragment;
load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset);
if (!output_op.is_source_needed()) {
compute_source_not_needed_(
output_op,
broadcast_fragment,
destination_iterator,
accumulators,
tensor_iterator);
}
else {
compute_source_needed_(
output_op,
broadcast_fragment,
destination_iterator,
accumulators,
source_iterator1,
source_iterator2,
tensor_iterator);
}
}
private:
CUTLASS_DEVICE
void load_broadcast_fragment_(
BroadcastFragment & broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
ElementVector const * broadcast_ptr, ///< Broadcast vector
MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses
MatrixCoord const &threadblock_offset ///< Threadblock's initial offset within the problem size space
) {
broadcast_fragment.clear();
// If no pointer is supplied, set with all zeros and avoid memory accesses
if (!broadcast_ptr) {
return;
}
int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column();
int thread_column_idx = threadblock_offset.column() + thread_initial_column;
broadcast_ptr += thread_initial_column;
NumericArrayConverter<ElementCompute, ElementVector, BroadcastDetail::kElementsPerAccess> converter;
using AccessType = AlignedArray<ElementVector, BroadcastDetail::kElementsPerAccess>;
using ComputeFragmentType = Array<ElementCompute, BroadcastDetail::kElementsPerAccess>;
ComputeFragmentType *frag_ptr = reinterpret_cast<ComputeFragmentType *>(&broadcast_fragment);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) {
AccessType loaded;
loaded.clear();
if (thread_column_idx < problem_size.column()) {
loaded = *reinterpret_cast<AccessType const *>(broadcast_ptr);
}
ComputeFragmentType cvt = converter(loaded);
frag_ptr[j] = cvt;
thread_column_idx += ThreadMap::Delta::kColumn;
broadcast_ptr += ThreadMap::Delta::kColumn;
}
}
template <class Seq>
struct acc2smem_source_not_needed;
template <size_t... Seq>
struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
template <int Advance>
CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
WarpTileIterator &warp_tile_iterator) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Advance; i++) {
++accum_fragment_iterator;
}
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
++accum_fragment_iterator;
warp_tile_iterator.store(accum_fragment);
if (p < Base::kFragmentsPerIteration - 1) {
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
}
}
if (Base::kFragmentsPerIteration > 1) {
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset *
(1 - Base::kFragmentsPerIteration));
}
}
CUTLASS_DEVICE
static void push(size_t pos,
AccumulatorFragmentIterator const &iterator_begin,
WarpTileIterator &warp_tile_iterator) {
int dummy[] = {
(pos == (Seq * Base::kFragmentsPerIteration)) &&
(helper<Seq * Base::kFragmentsPerIteration>(iterator_begin, warp_tile_iterator), 0)...};
CUTLASS_UNUSED(dummy[0]);
}
};
/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_not_needed_(
OutputOp const &output_op, ///< Output operator
BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additional tensor operand
) {
//
// Iterator over warp-level accumulator fragment
//
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
//
// Iterate over accumulator tile
//
// CUTLASS_PRAGMA_UNROLL
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1)
for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) {
//
// Convert and store fragment
//
__syncthreads();
acc2smem_source_not_needed<
cutlass::make_index_sequence<OutputTileIterator::kIterations /
Base::kFragmentsPerIteration>>::push(iter,
accum_fragment_iterator,
this->warp_tile_iterator_);
__syncthreads();
//
// Load fragments from shared memory
//
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
shared_load_iterator_.load(aligned_accum_fragment[0]);
if (p < Base::kFragmentsPerIteration - 1) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
}
else if (kPartitionsK > 1) {
plus <typename SharedLoadIterator::Fragment> add_fragments;
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < kPartitionsK; ++i) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
shared_load_iterator_.load(aligned_accum_fragment[i]);
aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
}
shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
}
//
// Apply output operation
//
typename OutputTileIterator::Fragment frag_Z;
typename TensorTileIterator::Fragment frag_T;
apply_output_operator_source_not_needed_(
frag_Z,
frag_T,
output_op,
aligned_accum_fragment[0],
broadcast_fragment);
//
// Conditionally store fragments
//
if (OutputOp::kStoreZ) {
destination_iterator.store(frag_Z);
++destination_iterator;
}
if (OutputOp::kStoreT) {
tensor_iterator.store(frag_T);
++tensor_iterator;
}
}
if (Base::kFragmentsPerIteration > 1) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
}
}
}
template<class Seq>
struct acc2smem_source_needed;
template <size_t... Seq>
struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
template<int Advance>
CUTLASS_DEVICE
static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
WarpTileIterator &warp_tile_iterator) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Advance; i++) {
++accum_fragment_iterator;
}
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
warp_tile_iterator.store(accum_fragment);
}
CUTLASS_DEVICE
static void push(size_t pos,
AccumulatorFragmentIterator const &iterator_begin,
WarpTileIterator &warp_tile_iterator) {
int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
CUTLASS_UNUSED(dummy[0]);
}
};
/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_needed_(
OutputOp const &output_op, ///< Output operator
BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator1, ///< Tile iterator for first source accumulator matrix
OutputTileIterator source_iterator2, ///< Tile iterator for second source accumulator matrix
TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additional tensor operand
) {
typename OutputTileIterator::Fragment source_fragment1;
source_fragment1.clear();
typename OutputTileIterator::Fragment source_fragment2;
source_fragment2.clear();
//
// Iterator over warp-level accumulator fragment
//
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
//
// Iterate over accumulator tile
//
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
//
// Load the source
//
source_iterator1.load(source_fragment1);
++source_iterator1;
if (source_iterator2.enabled()) {
source_iterator2.load(source_fragment2);
++source_iterator2;
}
//
// Convert and store fragment
//
__syncthreads();
acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
iter, accum_fragment_iterator, this->warp_tile_iterator_);
__syncthreads();
//
// Load fragments from shared memory
//
typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
shared_load_iterator_.load(aligned_accum_fragment[0]);
// If the number of k-slices is > 1 - perform a reduction amongst the k-slices
if (kPartitionsK > 1)
{
plus <typename SharedLoadIterator::Fragment> add_fragments;
const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK;
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < kPartitionsK; ++i) {
shared_load_iterator_.add_tile_offset({tile_row_offset, 0});
shared_load_iterator_.load(aligned_accum_fragment[i]);
aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
}
shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0});
}
//
// Apply output operation
//
typename OutputTileIterator::Fragment frag_Z;
typename TensorTileIterator::Fragment frag_T;
apply_output_operator_(
frag_Z,
frag_T,
output_op,
aligned_accum_fragment[0],
source_fragment1,
source_fragment2,
broadcast_fragment,
source_iterator2.enabled());
//
// Conditionally store fragments
//
if (OutputOp::kStoreZ) {
destination_iterator.store(frag_Z);
++destination_iterator;
}
if (OutputOp::kStoreT) {
tensor_iterator.store(frag_T);
++tensor_iterator;
}
}
}
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_(
typename OutputTileIterator::Fragment &frag_Z,
typename TensorTileIterator::Fragment &frag_T,
OutputOp const &output_op,
typename SharedLoadIterator::Fragment const &frag_AB,
typename OutputTileIterator::Fragment const &frag_C1,
typename OutputTileIterator::Fragment const &frag_C2,
BroadcastFragment const &frag_Broadcast,
bool frag_C2_enabled) {
using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;
AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);
AccumulatorAccessType const *frag_AB_ptr =
reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);
OutputAccessType const *frag_C1_ptr =
reinterpret_cast<OutputAccessType const *>(&frag_C1);
OutputAccessType const *frag_C2_ptr =
reinterpret_cast<OutputAccessType const *>(&frag_C2);
AccessTypeBroadcast const *frag_Broadcast_ptr =
reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
if (frag_C2_enabled) {
output_op(
frag_Z_ptr[i],
frag_T_ptr[i],
frag_AB_ptr[i],
frag_C1_ptr[i],
frag_C2_ptr[i],
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
} else {
output_op(
frag_Z_ptr[i],
frag_T_ptr[i],
frag_AB_ptr[i],
frag_C1_ptr[i],
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
}
}
}
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_source_not_needed_(
typename OutputTileIterator::Fragment &frag_Z,
typename TensorTileIterator::Fragment &frag_T,
OutputOp const &output_op,
typename SharedLoadIterator::Fragment const &frag_AB,
BroadcastFragment const &frag_Broadcast) {
using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;
AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);
AccumulatorAccessType const *frag_AB_ptr =
reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);
AccessTypeBroadcast const *frag_Broadcast_ptr =
reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
output_op(
frag_Z_ptr[i],
frag_T_ptr[i],
frag_AB_ptr[i],
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
}
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@ -124,6 +124,8 @@ public:
using Layout = layout::RowMajor;
using LongIndex = typename Layout::LongIndex;
static bool const kIsSingleSource = true;
/// The complete warp-level accumulator tile
using AccumulatorTile = typename Base::AccumulatorTile;
@ -294,7 +296,7 @@ public:
CUTLASS_DEVICE
void operator()(
OutputOp const &output_op, ///< Output operator
ElementVector * reduction_output_ptr, ///< Reduction output vector
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix

View File

@ -51,7 +51,7 @@
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/epilogue_base_streamk.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
////////////////////////////////////////////////////////////////////////////////
@ -78,8 +78,21 @@ template <
typename OutputOp_,
/// Number of interleaved k
int InterleavedK>
class InterleavedEpilogue :
public EpilogueBaseStreamK<
Shape_,
PartitionsK,
WarpMmaOperator_,
AccumulatorFragmentIterator_>
{
public:
using BaseStreamK = EpilogueBaseStreamK<
Shape_,
PartitionsK,
WarpMmaOperator_,
AccumulatorFragmentIterator_>;
using Shape = Shape_;
using WarpMmaOperator = WarpMmaOperator_;
static int const kPartitionsK = PartitionsK;
@ -90,6 +103,9 @@ class InterleavedEpilogue {
/// The complete warp-level accumulator tile
using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;
/// Fragment type used by the accumulator tile's fragment iterator
using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment;
/// Accumulator element
using ElementAccumulator = typename AccumulatorTile::Element;
@ -122,7 +138,8 @@ class InterleavedEpilogue {
gemm::GemmShape<Shape::kM / WarpMmaOperator::Shape::kM,
Shape::kN / WarpMmaOperator::Shape::kN, kPartitionsK>;
public:
static_assert(OutputTileIterator::kElementsPerAccess,
"This must not be zero.");
@ -134,15 +151,58 @@ class InterleavedEpilogue {
struct SharedStorage {};
public:
/// Constructor
CUTLASS_DEVICE
InterleavedEpilogue(
SharedStorage &shared_storage, ///< Shared storage object
int thread_idx, ///< ID of a thread within the threadblock
int warp_idx, ///< ID of warp within threadblock
int lane_idx) ///< Id of thread within warp
:
BaseStreamK(thread_idx)
{}
/// Aggregates the accumulator sets shared by peer blocks in the global workspace,
/// performing epilogue computations, writing to output
CUTLASS_DEVICE
void reduce(
int peer_idx_begin,
int peer_idx_end,
int reduce_fragment_idx,
ElementAccumulator *element_workspace,
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
OutputTileIterator source_iterator) ///< Tile iterator for source accumulator matrix
{
// Reduce peer accumulator fragments into one fragment
AccumulatorFragment accum_fragment;
BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace);
// Source-fragment data (zero-initialized for scenarios where the
// output operator allows us to skip loading it from global input)
typename OutputTileIterator::Fragment source_fragment;
source_fragment.clear();
if (output_op.is_source_needed())
{
source_iterator += reduce_fragment_idx;
source_iterator.load(source_fragment);
}
// Compute the output result
typename OutputTileIterator::Fragment output_fragment;
// Apply the output operator
apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment);
// Store the final result
destination_iterator += reduce_fragment_idx;
destination_iterator.store(output_fragment);
}
/// Streams the result to global memory
CUTLASS_DEVICE
@ -194,7 +254,7 @@ class InterleavedEpilogue {
//
typename OutputTileIterator::Fragment output_fragment;
apply_output_operator_source_not_needed(output_fragment, output_op, accum_fragment);
//
// Store the final result
@ -257,7 +317,7 @@ class InterleavedEpilogue {
//
typename OutputTileIterator::Fragment output_fragment;
apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment);
//
// Store the final result
@ -269,15 +329,16 @@ class InterleavedEpilogue {
}
}
protected:
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator(
typename OutputTileIterator::Fragment &output_fragment,
OutputOp const &output_op,
typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment,
typename OutputTileIterator::Fragment const &source_fragment)
{
OutputAccessType *output_frag_ptr =
reinterpret_cast<OutputAccessType *>(&output_fragment);
@ -300,11 +361,11 @@ class InterleavedEpilogue {
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_source_not_needed(
typename OutputTileIterator::Fragment &output_fragment,
OutputOp const &output_op,
typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment)
{
OutputAccessType *output_frag_ptr =
reinterpret_cast<OutputAccessType *>(&output_fragment);

View File

@ -680,6 +680,9 @@ public:
state_[2] = 0;
byte_pointer_ += params_.advance_tile;
store_byte_pointer_ += params_.advance_tile;
thread_start_row_ += ThreadMap::Shape::kGroup * ThreadMap::Shape::kRow
* ThreadMap::Shape::kCluster * ThreadMap::Shape::kTile;
}
}
}
@ -687,6 +690,60 @@ public:
return *this;
}
/// Advances a number of positions to load or store
CUTLASS_HOST_DEVICE
PredicatedTileIterator &operator+=(int increment)
{
// Row
state_[0] += increment;
int increment_row = state_[0] / ThreadMap::Count::kRow;
state_[0] = state_[0] % ThreadMap::Count::kRow;
byte_pointer_ += (params_.advance_row * increment);
store_byte_pointer_ += (params_.advance_row * increment);
thread_start_row_ += (ThreadMap::Shape::kRow * increment);
// Group
state_[1] += increment_row;
int increment_group = state_[1] / ThreadMap::Count::kGroup;
state_[1] = state_[1] % ThreadMap::Count::kGroup;
byte_pointer_ += (params_.advance_group * increment_row);
store_byte_pointer_ += (params_.advance_group * increment_row);
thread_start_row_ +=
(ThreadMap::Shape::kGroup - 1) *
ThreadMap::Shape::kRow *
ThreadMap::Count::kRow *
increment_row;
// Cluster
state_[2] += increment_group;
int increment_cluster = state_[2] / ThreadMap::Count::kCluster;
state_[2] = state_[2] % ThreadMap::Count::kCluster;
byte_pointer_ += (params_.advance_cluster * increment_group);
store_byte_pointer_ += (params_.advance_cluster * increment_group);
thread_start_row_ +=
ThreadMap::Count::kGroup *
ThreadMap::Shape::kGroup *
ThreadMap::Count::kRow *
ThreadMap::Shape::kRow *
increment_group;
// Tile
byte_pointer_ += (params_.advance_tile * increment_cluster);
store_byte_pointer_ += (params_.advance_tile * increment_cluster);
thread_start_row_ +=
ThreadMap::Shape::kGroup *
ThreadMap::Shape::kRow *
ThreadMap::Shape::kCluster *
ThreadMap::Shape::kTile *
increment_cluster;
return *this;
}
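// Example of the carry propagation above: with ThreadMap::Count::kRow == 4,
// advancing by increment == 6 from state_[0] == 2 leaves state_[0] == 0 and
// produces increment_row == 2, which the group stage then absorbs exactly as
// two full passes of operator++ over the row dimension would.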
///< Efficiently disables all accesses guarded by mask
CUTLASS_DEVICE void clear_mask() {
mask_.clear();
@ -944,6 +1001,23 @@ public:
return *this;
}
/// Advances a number of positions to load or store
CUTLASS_HOST_DEVICE
InterleavedPredicatedTileIterator &operator+=(int increment)
{
// Contiguous
iteration_contiguous_ += increment;
int increment_strided = iteration_contiguous_ / ThreadMap::Iterations::kContiguous;
iteration_contiguous_ = iteration_contiguous_ % ThreadMap::Iterations::kContiguous;
byte_pointer_ += (params_.advance_row * increment);
// Strided
iteration_strided_ += increment_strided;
byte_pointer_ += (params_.advance_column * increment_strided);
return *this;
}
///< Efficiently disables all accesses guarded by mask
CUTLASS_DEVICE void clear_mask() {
mask_.clear();

View File

@ -0,0 +1,445 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/permute.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/epilogue/threadblock/output_tile_thread_map.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h"
#include "cutlass/conv/conv2d_problem_size.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////
namespace epilogue {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Tile iterator used to load and store output tile from global memory in epilogue.
///
/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator
///
template <
typename ThreadMap_, ///< Thread map (concept: PitchLinearThreadMap)
typename Element_, ///< Element data type
typename ThreadOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1>,
typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1>
>
class PredicatedTileIteratorDirectConv {
public:
using ThreadMap = ThreadMap_;
using Shape = typename ThreadMap::Shape;
using ThreadOutputShape = ThreadOutputShape_;
using ThreadBlockOutputShape = ThreadBlockOutputShape_;
using Element = Element_;
using Layout = layout::RowMajor;
using TensorRef = TensorRef<Element, Layout>;
using ConstTensorRef = typename TensorRef::ConstTensorRef;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorCoord = MatrixCoord;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
static int const kThreads = ThreadMap::kThreads;
using ConvProblemSize = typename cutlass::conv::Conv2dProblemSize;
/// Fragment object
using Fragment = Array<Element, ThreadMap::Iterations::kCount * kElementsPerAccess>;
/// Memory access size
using AccessType = AlignedArray<Element, kElementsPerAccess>;
static int const kLoadsPerAccess = AccessType::kElements / AccessType::kElements;
using ThreadTileCount = MatrixShape<
ThreadBlockOutputShape::kH / ThreadOutputShape::kH,
ThreadBlockOutputShape::kW / ThreadOutputShape::kW
>;
//
// Parameters struct
//
/// Uses a non-template class
struct Params : PredicatedTileIteratorDirect2dConvParams {
using Base = PredicatedTileIteratorDirect2dConvParams;
CUTLASS_HOST_DEVICE
Params() { }
CUTLASS_HOST_DEVICE
Params(Layout const &layout, cutlass::conv::Conv2dProblemSize const &problem_size):
PredicatedTileIteratorDirect2dConvParams(
layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess,
problem_size,
{ThreadBlockOutputShape::kH, ThreadBlockOutputShape::kW}
)
{ }
CUTLASS_HOST_DEVICE
Params(Base const &base) :
Base(base) { }
};
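// Note: AccessType is AlignedArray<Element, kElementsPerAccess>, so
// sizeof(AccessType) / kElementsPerAccess == sizeof(Element) and the value
// passed to the base class works out to the output tensor's row stride in
// bytes.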
/// Mask object
struct Mask {
static int const kCount = ThreadMap::Iterations::kContiguous;
/// Predicate state
bool predicates[kCount];
//
// Mask
//
CUTLASS_HOST_DEVICE
Mask() {
enable();
}
///< Efficiently disables all accesses guarded by mask
CUTLASS_HOST_DEVICE void clear() {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kCount; ++i) {
predicates[i] = false;
}
}
///< Efficiently enables all accesses guarded by mask
CUTLASS_DEVICE void enable() {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kCount; ++i) {
predicates[i] = true;
}
}
};
private:
//
// Data members
//
/// Parameters structure containing reference and precomputed state.
PredicatedTileIteratorDirect2dConvParams params_;
/// Byte-level pointer
uint8_t *byte_pointer_;
/// Pointer to the output tensor in units of Element
Element *pointer_;
/// Array of boolean values to contain steady-state predicates
Mask mask_;
/// Extent of the matrix tile in rows
Index extent_row_;
/// Extent of the matrix tile in columns
Index extent_column_;
/// A thread's starting row position (assuming steady-state predicates have been computed)
Index thread_start_row_;
/// A thread's starting column
Index thread_start_column_;
/// Initial thread output location
int thread_start_n_, thread_start_p_, thread_start_q_;
/// Current threadblock tile index
int tile_index_;
//
// Static asserts about internal strides
//
static_assert(sizeof(extent_row_) == 4, "Expected 32b extents");
static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents");
static_assert(sizeof(PredicatedTileIteratorDirect2dConvParams::stride) == 8, "Expected 64b strides");
private:
//
// Methods
//
public:
//
// Methods
//
/// Constructor
CUTLASS_DEVICE
PredicatedTileIteratorDirectConv(
PredicatedTileIteratorDirect2dConvParams const & params,
Element *pointer,
TensorCoord extent,
int thread_idx,
TensorCoord threadblock_offset = TensorCoord()
):
params_(params), pointer_(pointer)
{
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx);
extent_row_ = extent.row();
extent_column_ = extent.column();
// stride dim (PQ)
thread_start_row_ = thread_offset.column();
// contiguous dim (Channels)
thread_start_column_ = threadblock_offset.column() + thread_offset.row();
tile_index_ = threadblock_offset.row();
set_tile_index(0);
}
/// Sets the threadblock tile index, recomputing the thread's base output coordinates and predicates
CUTLASS_HOST_DEVICE
void set_tile_index(const int index) {
int residual;
params_.pq_divmod(thread_start_n_, residual, tile_index_ + index);
params_.q_divmod(thread_start_p_, thread_start_q_, residual);
// Compute the base output coord of ThreadBlock
thread_start_p_ *= ThreadBlockOutputShape::kH;
thread_start_q_ *= ThreadBlockOutputShape::kW;
// Initialize predicates
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
mask_.predicates[c] = ((thread_start_column_
+ c * ThreadMap::Delta::kContiguous) < extent_column_);
}
// Null pointer performs no accesses
if (!pointer_) {
mask_.clear();
}
}
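// Example of the decomposition above (assuming an 8x8 ThreadBlockOutputShape
// and P == Q == 28): tiles_p == tiles_q == 4, so pq_divmod divides by 16 and
// q_divmod by 4; a tile index of 37 yields thread_start_n_ == 2 and tile
// coordinates (p, q) == (1, 1), which the multiplies above scale to the
// element offsets (8, 8).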
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_byte_offset(Fragment &frag, int64_t byte_offset) const {
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
int frag_base_idx = s * ThreadMap::Iterations::kContiguous + c;
int current_row = thread_start_row_ + s * ThreadMap::Delta::kStrided;
int p = current_row / ThreadBlockOutputShape::kW;
int q = current_row % ThreadBlockOutputShape::kW;
int current_p = thread_start_p_ + p;
int current_q = thread_start_q_ + q;
bool row_guard = (current_p) < params_.P && (current_q) < params_.Q &&
(thread_start_n_ < params_.N) && current_row < ThreadMap::Shape::kStrided;
int output_row_offset =
thread_start_n_ * params_.stride_n + current_p * params_.stride_p + current_q;
uint8_t *byte_pointer =
reinterpret_cast<uint8_t *>(pointer_) +
LongIndex(output_row_offset) * LongIndex(params_.stride) +
LongIndex(thread_start_column_ + c * ThreadMap::Delta::kContiguous) *
sizeof(AccessType) / kElementsPerAccess;
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset);
bool guard = row_guard && mask_.predicates[c];
cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
frag_ptr[frag_base_idx], (void *)&memory_pointer[0], guard);
}
}
}
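// Note: out-of-bounds accesses are handled entirely by the guard flag passed
// to arch::global_load above; row_guard covers the N, P, Q extents while
// mask_.predicates[c] covers the contiguous (channel) extent, so no separate
// bounds branch is needed around the access.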
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment &frag) const {
load_with_byte_offset(frag, 0);
}
/// Stores a fragment to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) const {
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
int frag_base_idx = s * ThreadMap::Iterations::kContiguous + c;
int current_row = thread_start_row_ + s * ThreadMap::Delta::kStrided;
int p = current_row / ThreadBlockOutputShape::kW;
int q = current_row % ThreadBlockOutputShape::kW;
int current_p = thread_start_p_ + p;
int current_q = thread_start_q_ + q;
bool row_guard = (current_p) < params_.P && (current_q) < params_.Q &&
(thread_start_n_ < params_.N) && current_row < ThreadMap::Shape::kStrided;
int output_row_offset =
thread_start_n_ * params_.stride_n + current_p * params_.stride_p + current_q;
uint8_t *byte_pointer =
reinterpret_cast<uint8_t *>(pointer_) +
LongIndex(output_row_offset) * LongIndex(params_.stride) +
LongIndex(thread_start_column_ + c * ThreadMap::Delta::kContiguous) *
sizeof(AccessType) / kElementsPerAccess;
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset);
bool guard = row_guard && mask_.predicates[c];
cutlass::arch::global_store<AccessType, sizeof(AccessType)>(
frag_ptr[frag_base_idx], (void *)&memory_pointer[0], guard);
}
}
}
/// Stores a fragment to memory
CUTLASS_DEVICE
void store(Fragment const &frag) const {
store_with_byte_offset(frag, 0);
}
CUTLASS_DEVICE
MatrixCoord thread_start() const {
return MatrixCoord(thread_start_row_, thread_start_column_);
}
/// Need to get the thread start row from the tile iterator
CUTLASS_DEVICE
int32_t thread_start_row() const {
return thread_start_row_;
}
/// Need to get the thread start column from the tile iterator
CUTLASS_DEVICE
int32_t thread_start_column() const {
return thread_start_column_;
}
/// Extent of the matrix in rows
CUTLASS_DEVICE
Index extent_row() const {
return extent_row_;
}
/// Extent of the matrix in columns
CUTLASS_DEVICE
Index extent_column() const {
return extent_column_;
}
/// Advances to the next position to load or store
CUTLASS_HOST_DEVICE
PredicatedTileIteratorDirectConv &operator++() {
// do nothing
return *this;
}
///< Efficiently disables all accesses guarded by mask
CUTLASS_DEVICE void clear_mask() {
mask_.clear();
}
///< Efficiently enables all accesses guarded by mask
CUTLASS_DEVICE void enable_mask() {
mask_.enable();
}
///< Gets the mask
CUTLASS_DEVICE void get_mask(Mask &mask) const {
mask = mask_;
}
///< Sets the mask
CUTLASS_DEVICE void set_mask(Mask const &mask) {
mask_ = mask;
}
};
///////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
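// The iterators above funnel every global access through cutlass::arch::global_load /
// cutlass::arch::global_store with a boolean guard, so out-of-bounds threads issue no
// memory transaction at all. A minimal sketch of that idiom in isolation (the fragment
// type, pointers, and guard below are illustrative assumptions, not part of this file):

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/arch/memory.h"

CUTLASS_DEVICE
void guarded_copy_sketch(float const *src, float *dst, bool guard) {
cutlass::AlignedArray<float, 4> frag;
// Predicated 16-byte load; when guard is false the load is skipped entirely and
// frag keeps its previous contents.
cutlass::arch::global_load<cutlass::AlignedArray<float, 4>, sizeof(frag)>(
frag, (void const *)src, guard);
// Predicated 16-byte store; likewise a no-op when guard is false.
cutlass::arch::global_store<cutlass::AlignedArray<float, 4>, sizeof(frag)>(
frag, (void *)dst, guard);
}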

View File

@ -35,9 +35,12 @@
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/conv/conv2d_problem_size.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
@ -245,6 +248,87 @@ struct PredicatedTileIteratorParams {
};
///////////////////////////////////////////////////////////////////////////////
//
// Parameters struct for PredicatedTileIteratorDirect2dConv
//
struct PredicatedTileIteratorDirect2dConvParams {
using Index = int32_t;
using LongIndex = int64_t;
//
// Data members
//
FastDivmod pq_divmod;
FastDivmod q_divmod;
LongIndex stride;
LongIndex stride_n;
LongIndex stride_p;
int N;
int P;
int Q;
//
// Methods
//
CUTLASS_HOST_DEVICE
Status initialize(LongIndex stride_,
cutlass::conv::Conv2dProblemSize const &problem_size,
MatrixCoord threadblock_output_shape) {
stride = stride_; // The stride per row of output tensor (bytes)
stride_n = problem_size.P * problem_size.Q;
stride_p = problem_size.Q;
N = problem_size.N;
P = problem_size.P;
Q = problem_size.Q;
// Fast divmod over the tiled output P, Q extents
if (threadblock_output_shape.row() != 0 && threadblock_output_shape.column() != 0) {
int tiles_p =
(problem_size.P + (threadblock_output_shape.row() - 1)) / (threadblock_output_shape.row());
int tiles_q = (problem_size.Q + (threadblock_output_shape.column() - 1)) /
(threadblock_output_shape.column());
pq_divmod = FastDivmod(tiles_p * tiles_q);
q_divmod = FastDivmod(tiles_q);
}
return Status::kSuccess;
}
CUTLASS_HOST_DEVICE
Status initialize(
Index stride_,
cutlass::conv::Conv2dProblemSize const &problem_size = cutlass::conv::Conv2dProblemSize(),
MatrixCoord threadblock_output_shape = MatrixCoord()) {
return initialize(LongIndex(stride_), problem_size, threadblock_output_shape);
}
CUTLASS_HOST_DEVICE
PredicatedTileIteratorDirect2dConvParams() { initialize(LongIndex(0)); }
CUTLASS_HOST_DEVICE
PredicatedTileIteratorDirect2dConvParams(Index stride,
cutlass::conv::Conv2dProblemSize const &problem_size,
MatrixCoord threadblock_output_shape) {
initialize(stride, problem_size, threadblock_output_shape);
}
CUTLASS_HOST_DEVICE
PredicatedTileIteratorDirect2dConvParams(LongIndex stride,
cutlass::conv::Conv2dProblemSize const &problem_size,
MatrixCoord threadblock_output_shape) {
initialize(stride, problem_size, threadblock_output_shape);
}
};
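// A hypothetical host-side sketch of binding these parameters; the row stride and the
// 8x8 threadblock output tile below are illustrative assumptions:

inline Status initialize_direct2d_conv_params_sketch(
PredicatedTileIteratorDirect2dConvParams &params,
int64_t output_row_stride_bytes,
cutlass::conv::Conv2dProblemSize const &problem_size) {
// Precomputes the FastDivmod pair so the epilogue can recover a tile's (n, p, q)
// position with multiply-high and shift instead of hardware integer division.
return params.initialize(output_row_stride_bytes, problem_size, MatrixCoord(8, 8));
}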
///////////////////////////////////////////////////////////////////////////////
// InterleavedPredicatedTileIterator
///////////////////////////////////////////////////////////////////////////////

View File

@ -201,6 +201,11 @@ public:
}
}
/// Set base smem address
CUTLASS_DEVICE
void set_smem_base_address(Index address) {
}
/// Loads a fragment
CUTLASS_DEVICE
void load(Fragment &frag) const {

View File

@ -234,6 +234,10 @@ public:
}
}
/// Set base smem address
CUTLASS_DEVICE
void set_smem_base_address(Index address) {}
/// Loads a fragment
CUTLASS_DEVICE
void load(Fragment &frag) const {
@ -395,6 +399,10 @@ public:
}
}
/// Set base smem address
CUTLASS_DEVICE
void set_smem_base_address(Index address) {}
/// Loads a fragment
CUTLASS_DEVICE
void load(Fragment &frag) {
@ -556,6 +564,10 @@ public:
}
}
/// Set base smem address
CUTLASS_DEVICE
void set_smem_base_address(Index address) {}
/// Loads a fragment
CUTLASS_DEVICE
void load(Fragment &frag) {

View File

@ -0,0 +1,194 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
This assumes the shared memory tile is in a permuted layout which avoids bank conflicts on loading.
When the fragment is loaded into registers, it matches the row-major thread map assumed by
the predicated tile iterator writing to global memory.
The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/threadblock/output_tile_thread_map.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Tile iterator used to load output tile from shared memory in epilogue.
///
/// Satisfies: ReadableTileIterator
///
template <typename ThreadMap_, ///< Thread map (concept: PitchLinearThreadMap)
typename Element_, ///< Element data type
int MaxAlignment = ThreadMap_::kElementsPerAccess * sizeof_bits<Element_>::value / 8>
class SharedLoadIteratorPitchLiner {
public:
using ThreadMap = ThreadMap_;
using Element = Element_;
using Layout = layout::RowMajor;
using TensorRef = TensorRef<Element, Layout>;
using ConstTensorRef = typename TensorRef::ConstTensorRef;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorCoord = MatrixCoord;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
static int const kMinAlignment =
ThreadMap_::kElementsPerAccess * sizeof_bits<Element_>::value / 8;
static int const kAlignment = (MaxAlignment < kMinAlignment ? MaxAlignment : kMinAlignment);
static int const kThreads = ThreadMap::kThreads;
/// Fragment object
using Fragment = Array<Element, ThreadMap::Iterations::kCount * kElementsPerAccess>;
/// Memory access size
using AccessType = AlignedArray<Element, kElementsPerAccess, kAlignment>;
/// Vector type used for SMEM loads
using LoadType =
AlignedArray<Element,
const_min(128 / sizeof_bits<Element>::value, ThreadMap::kElementsPerAccess),
const_min(16, kAlignment)>;
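/// Number of LoadType vectors (each at most 128 bits wide) issued per logical
/// AccessType access; load_with_pointer_offset loops over these back to back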
static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements;
private:
//
// Data members
//
/// Byte-level pointer
uint8_t *byte_pointer_;
/// Stride along adjacent rows
int stride_;
/// Base address offset
Index base_smem_address_;
public:
//
// Methods
//
/// Constructor
CUTLASS_DEVICE
SharedLoadIteratorPitchLiner(TensorRef ref, int thread_idx)
: byte_pointer_(reinterpret_cast<uint8_t *>(ref.data())),
stride_((ref.stride(0) * sizeof_bits<Element>::value) / 8),
base_smem_address_(0) {
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx);
// Initialize pointer
// thread_offset.row() is contiguous dim
// thread_offset.column() is stride dim
byte_pointer_ += thread_offset.row() * sizeof(AccessType) / kElementsPerAccess +
thread_offset.column() * stride_;
}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
}
CUTLASS_DEVICE
void add_tile_offset(TensorCoord const &offset) {
byte_pointer_ +=
offset.row() * ThreadMap::StorageShape::kContiguous * sizeof(AccessType) / kElementsPerAccess +
offset.column() * ThreadMap::StorageShape::kStrided * stride_;
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
uint8_t const *byte_pointer =
byte_pointer_ + s * ThreadMap::Delta::kStrided * stride_ +
c * ThreadMap::Delta::kContiguous * ThreadMap::kElementsPerAccess *
sizeof_bits<Element>::value / 8 +
pointer_offset * sizeof_bits<Element>::value / 8 + base_smem_address_;
int frag_base_idx = s * ThreadMap::Iterations::kContiguous + c;
LoadType *frag_ptr = reinterpret_cast<LoadType *>(&frag);
LoadType const *memory_pointer = reinterpret_cast<LoadType const *>(byte_pointer);
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kLoadsPerAccess; ++v) {
frag_ptr[frag_base_idx * kLoadsPerAccess + v] = memory_pointer[v];
}
}
}
}
/// Set smem base address
CUTLASS_DEVICE
void set_smem_base_address(Index address) { base_smem_address_ = address; }
/// Loads a fragment
CUTLASS_DEVICE
void load(Fragment &frag) const { load_with_pointer_offset(frag, 0); }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -240,12 +240,301 @@ public:
void load(Fragment &frag) const {
load_with_pointer_offset(frag, 0);
}
/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) {
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Template for reading and writing tiles of accumulators to shared memory
template <typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape)
typename Operator_, ///< matrix multiply operation (concept: arch::Mma)
typename Element_, ///< data type of element to be written
typename Layout_, ///< target shared memory layout
typename MmaSimtPolicy_ ///< policy defining lane arrangement (concept: MmaSimtPolicy)
>
class TileIteratorSimtDirectConv {
public:
using WarpShape = WarpShape_;
using Operator = Operator_;
using Element = Element_;
using Layout = layout::RowMajor;
using TensorRef = TensorRef<Element, Layout>; ///< Tensor Reference object
using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor
using Index = typename TensorRef::Index;
using LongIndex = typename TensorRef::LongIndex;
using Policy = SimtPolicy<WarpShape, Operator, Layout, MmaSimtPolicy_>;
/// Shape of the tile in memory
using Shape = MatrixShape<Policy::kRowsPerIteration, WarpShape::kN>;
/// This is the fragment size produced by one access of the iterator.
using Fragment = Array<typename Operator::ElementC, Policy::kElementsPerIteration>;
/// This is the complete warp-level accumulator tile.
using AccumulatorTile = Array<typename Operator::ElementC, Policy::kAccumulatorElementCount>;
/// Number of times this iterator can be incremented
static int const kIterations = Policy::kIterations;
/// Padding quantity
using Padding = MatrixShape<0, 0>;
private:
/// Storage type for accessing memory
using AccessType = AlignedArray<
Element,
Policy::kElementsPerAccess
>;
//
// Data members
//
/// Internal pointer to memory
AccessType *pointer_;
/// Internal layout object
Layout layout_;
/// Base smem offset;
Index base_smem_address_;
public:
/// Default constructor
CUTLASS_HOST_DEVICE
TileIteratorSimtDirectConv() : pointer_(nullptr), base_smem_address_(0) {}
/// Constructor from TensorRef
CUTLASS_HOST_DEVICE
TileIteratorSimtDirectConv(
TensorRef const &ref,
unsigned lane_id
):
pointer_(reinterpret_cast<AccessType *>(ref.data())),
layout_(ref.stride()[0] / AccessType::kElements),
base_smem_address_(0) {
auto lane_layout = Policy::MmaSimtPolicy::get_lane_layout();
MatrixCoord lane_offset = lane_layout.inverse(lane_id);
pointer_ += layout_({
lane_offset.row(),
lane_offset.column() * Policy::kElementsPerAccess / int(AccessType::kElements)
});
}
/// Adds a pointer offset
CUTLASS_HOST_DEVICE
TileIteratorSimtDirectConv & add_pointer_offset(Index pointer_offset) {
pointer_ += pointer_offset / AccessType::kElements;
return *this;
}
///< advances in units of whole tiles along the logical coordinate space of the tensor
CUTLASS_HOST_DEVICE
TileIteratorSimtDirectConv & add_tile_offset(TensorCoord const &tile_offset) {
pointer_ += layout_({
tile_offset.row() * Shape::kRow,
(tile_offset.column() * Shape::kColumn / int(AccessType::kElements))
});
return *this;
}
///< advances in units of whole tiles along the logical coordinate space of the tensor
CUTLASS_HOST_DEVICE
TileIteratorSimtDirectConv & operator+=(TensorCoord const &tile_offset) {
add_tile_offset(tile_offset);
return *this;
}
/// Store
CUTLASS_HOST_DEVICE
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
// Vectorized stores through the base smem address offset
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
AccessType *store_pointer = reinterpret_cast<AccessType *>(reinterpret_cast<uint8_t *>(pointer_) + base_smem_address_);
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Policy::kAccessesPerIteration; ++n) {
store_pointer[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)] = frag_ptr[n];
}
}
/// Store
CUTLASS_HOST_DEVICE
void store(Fragment const &frag) {
store_with_pointer_offset(frag, 0);
}
/// Load
CUTLASS_HOST_DEVICE
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Policy::kAccessesPerIteration; ++n) {
frag_ptr[n] = pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)];
}
}
/// Load
CUTLASS_HOST_DEVICE
void load(Fragment &frag) const {
load_with_pointer_offset(frag, 0);
}
/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address){
base_smem_address_ = address;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Template for reading and writing tiles of accumulators to shared memory
template <typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape)
typename ThreadOutputShape_, ///< Size of the per-thread output tile (concept: TensorNHWC)
typename ThreadBlockOutputShape_, ///< Size of the per-threadblock output tile (concept: TensorNHWC)
typename Operator_, ///< matrix multiply operation (concept: arch::Mma)
typename Element_, ///< data type of element to be written
typename Layout_, ///< target shared memory layout
typename MmaSimtPolicy_ ///< policy defining lane arrangement (concept: MmaSimtPolicy)
>
class TileIteratorSimtDirect2dConv {
public:
using WarpShape = WarpShape_;
using ThreadOutputShape = ThreadOutputShape_;
using ThreadBlockOutputShape = ThreadBlockOutputShape_;
using Operator = Operator_;
using Element = Element_;
using Layout = layout::RowMajor;
using MmaSimtPolicy = MmaSimtPolicy_;
using TensorRef = TensorRef<Element, Layout>; ///< Tensor Reference object
using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor
using Index = typename TensorRef::Index;
using LongIndex = typename TensorRef::LongIndex;
// Thread-level shape of a fragment
using ThreadShape = MatrixShape<ThreadOutputShape::kNHW, ThreadOutputShape::kC>;
static_assert(!(ThreadShape::kColumn % MmaSimtPolicy::LaneMmaShape::kN),
"Thread-level GEMM must be divisible by Policy::LaneMmaShape.");
using ThreadTileCount = MatrixShape<ThreadBlockOutputShape::kH / ThreadOutputShape::kH,
ThreadBlockOutputShape::kW / ThreadOutputShape::kW>;
using Iterations =
MatrixShape<ThreadShape::kRow, ThreadShape::kColumn / MmaSimtPolicy::LaneMmaShape::kN>;
/// This is the complete warp-level accumulator tile.
using AccumulatorTile = typename Operator::FragmentC;
/// This is the fragment size produced by one access of the iterator.
using Fragment = AccumulatorTile;
/// Padding quantity
using Padding = MatrixShape<0, 0>;
private:
// Storage type for accessing memory
using AccessType = AlignedArray<Element, MmaSimtPolicy::LaneMmaShape::kN>;
//
// Data members
//
/// Internal pointer to memory
AccessType *pointer_;
/// Internal layout object
Layout layout_;
/// Base smem offset;
Index base_smem_address_;
public:
/// Default constructor
CUTLASS_HOST_DEVICE
TileIteratorSimtDirect2dConv() : pointer_(nullptr), base_smem_address_(0) {}
/// Constructor from TensorRef
CUTLASS_HOST_DEVICE
TileIteratorSimtDirect2dConv(TensorRef const &ref, unsigned thread_id, unsigned lane_id)
: pointer_(reinterpret_cast<AccessType *>(ref.data())),
layout_(ref.stride()[0] / AccessType::kElements),
base_smem_address_(0) {
auto lane_layout = MmaSimtPolicy::get_lane_layout();
MatrixCoord lane_offset = lane_layout.inverse(lane_id);
// Get the base HW offset of the current thread
const int threadgroup = thread_id / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC);
const int base_p = (threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH;
const int base_q = (threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW;
const int row_offset = base_p * ThreadBlockOutputShape::kW + base_q;
pointer_ += layout_(
{row_offset,
lane_offset.column() * MmaSimtPolicy::LaneMmaShape::kN / int(AccessType::kElements)});
}
/// Adds a pointer offset
CUTLASS_HOST_DEVICE
TileIteratorSimtDirect2dConv &add_pointer_offset(Index pointer_offset) {
pointer_ += pointer_offset / AccessType::kElements;
return *this;
}
/// Store
CUTLASS_HOST_DEVICE
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
AccessType *storer_pointer_ =
reinterpret_cast<AccessType *>(reinterpret_cast<uint8_t *>(pointer_) + base_smem_address_);
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < ThreadOutputShape::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < ThreadOutputShape::kW; ++w) {
CUTLASS_PRAGMA_UNROLL
for (int col = 0; col < Iterations::kColumn; ++col) {
int offset = (w + h * ThreadBlockOutputShape::kW) *
(ThreadBlockOutputShape::kC / AccessType::kElements) +
col;
storer_pointer_[offset + pointer_offset / int(AccessType::kElements)] =
frag_ptr[w + h * ThreadOutputShape::kW + col];
}
}
}
}
/// Store
CUTLASS_HOST_DEVICE
void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); }
/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) { base_smem_address_ = address; }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Template for reading and writing tiles of accumulators to shared memory
template <
typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape)
@ -482,6 +771,10 @@ public:
return add_tile_offset({1, 0});
}
/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) {
}
};

View File

@ -228,6 +228,11 @@ public:
TileIteratorTensorOp & operator++() {
return add_tile_offset({1, 0});
}
/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) {
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -420,6 +425,11 @@ public:
TileIteratorTensorOp & operator++() {
return add_tile_offset({0, 1});
}
/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) {
}
};
@ -645,6 +655,11 @@ public:
TileIteratorTensorOpCanonical & operator++() {
return add_tile_offset({1, 0});
}
/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) {
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -304,6 +304,11 @@ public:
void load(Fragment &frag) const {
load_with_pointer_offset(frag, 0);
}
/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) {
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -506,6 +511,11 @@ public:
void store(Fragment const &frag) {
store_with_pointer_offset(frag, 0);
}
/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) {
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -697,6 +707,11 @@ public:
void store(Fragment const &frag) {
store_with_pointer_offset(frag, 0);
}
/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) {
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -242,6 +242,11 @@ public:
void load(Fragment const &frag) {
load_with_pointer_offset(frag, 0);
}
/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) {
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -419,6 +424,11 @@ public:
void load(Fragment const &frag) {
load_with_pointer_offset(frag, 0);
}
/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) {
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -207,6 +207,12 @@ public:
void load(Fragment &frag) const {
load_with_pointer_offset(frag, 0);
}
/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) {
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -280,6 +280,36 @@ struct FastDivmod {
unsigned int multiplier;
unsigned int shift_right;
/// Find quotient and remainder using device-side intrinsics
CUTLASS_HOST_DEVICE
void fast_divmod(int& quotient, int& remainder, int dividend) const {
#if defined(__CUDA_ARCH__)
// Use IMUL.HI if divisor != 1, else simply copy the source.
quotient = (divisor != 1) ? __umulhi(dividend, multiplier) >> shift_right : dividend;
#else
quotient = int((divisor != 1) ? int(((int64_t)dividend * multiplier) >> 32) >> shift_right : dividend);
#endif
// The remainder.
remainder = dividend - (quotient * divisor);
}
/// For long int input
CUTLASS_HOST_DEVICE
void fast_divmod(int& quotient, int64_t& remainder, int64_t dividend) const {
#if defined(__CUDA_ARCH__)
// Use IMUL.HI if divisor != 1, else simply copy the source.
quotient = (divisor != 1) ? __umulhi(dividend, multiplier) >> shift_right : dividend;
#else
quotient = int((divisor != 1) ? ((dividend * multiplier) >> 32) >> shift_right : dividend);
#endif
// The remainder.
remainder = dividend - (quotient * divisor);
}
/// Construct the FastDivmod object, in host code ideally.
///
/// This precomputes some values based on the divisor and is computationally expensive.
@ -288,17 +318,35 @@ struct FastDivmod {
FastDivmod(): divisor(0), multiplier(0), shift_right(0) { }
CUTLASS_HOST_DEVICE
FastDivmod(int divisor_): divisor(divisor_) {
find_divisor(multiplier, shift_right, divisor);
FastDivmod(int divisor): divisor(divisor) {
if (divisor != 1) {
unsigned int p = 31 + find_log2(divisor);
unsigned m = unsigned(((1ull << p) + unsigned(divisor) - 1) / unsigned(divisor));
multiplier = m;
shift_right = p - 32;
} else {
multiplier = 0;
shift_right = 0;
}
}
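// The branch above implements the classic round-up "magic number" division scheme:
// with p = 31 + ceil(log2(d)) and m = ceil(2^p / d), the quotient of a 32-bit
// unsigned dividend x is ((x * m) >> 32) >> (p - 32), i.e. one multiply-high and one
// small shift in place of a hardware division.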
/// Computes integer division and modulus using precomputed values. This is computationally
/// inexpensive.
CUTLASS_HOST_DEVICE
void operator()(int &quotient, int &remainder, int dividend) const {
fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right);
fast_divmod(quotient, remainder, dividend);
}
/// Computes integer division using precomputed values. This is computationally
/// inexpensive.
CUTLASS_HOST_DEVICE
int div(int dividend) const {
int quotient, remainder;
fast_divmod(quotient, remainder, dividend);
return quotient;
}
/// Computes integer division and modulus using precomputed values. This is computationally
/// inexpensive.
@ -307,7 +355,7 @@ struct FastDivmod {
CUTLASS_HOST_DEVICE
int divmod(int &remainder, int dividend) const {
int quotient;
fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right);
fast_divmod(quotient, remainder, dividend);
return quotient;
}
@ -315,7 +363,7 @@ struct FastDivmod {
/// inexpensive.
CUTLASS_HOST_DEVICE
void operator()(int &quotient, int64_t &remainder, int64_t dividend) const {
fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right);
fast_divmod(quotient, remainder, dividend);
}
/// Computes integer division and modulus using precomputed values. This is computationally
@ -323,9 +371,14 @@ struct FastDivmod {
CUTLASS_HOST_DEVICE
int divmod(int64_t &remainder, int64_t dividend) const {
int quotient;
fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right);
fast_divmod(quotient, remainder, dividend);
return quotient;
}
/// Returns the divisor when cast to integer
CUTLASS_HOST_DEVICE
operator int() const { return divisor; }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
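// A minimal usage sketch of the struct above (the divisor and dividend values are
// illustrative):

#include "cutlass/fast_math.h"

inline void fast_divmod_example() {
cutlass::FastDivmod divmod_q(640); // precompute multiplier/shift once (the expensive step)
int quotient, remainder;
divmod_q(quotient, remainder, 100000); // multiply-high + shift + one multiply-subtract
// quotient == 156 (100000 / 640), remainder == 160 (100000 % 640)
}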

include/cutlass/float8.h (new file, 1213 lines)

File diff suppressed because it is too large

View File

@ -0,0 +1,65 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Defines categories for floating point numbers for use in NVRTC-compiled code
*/
#pragma once
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
// All floating-point numbers can be put in one of these categories.
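// (A glibc-style idiom: each enumerator is declared, then immediately shadowed by a
// macro of the same name and value, so both the enum constant and the preprocessor
// symbol are available to client code.)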
enum {
FP_NAN =
# define FP_NAN 0
FP_NAN,
FP_INFINITE =
# define FP_INFINITE 1
FP_INFINITE,
FP_ZERO =
# define FP_ZERO 2
FP_ZERO,
FP_SUBNORMAL =
# define FP_SUBNORMAL 3
FP_SUBNORMAL,
FP_NORMAL =
# define FP_NORMAL 4
FP_NORMAL
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////

File diff suppressed because it is too large

View File

@ -211,7 +211,7 @@ public:
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int smem_capacity = -1) {
CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()");
CUTLASS_TRACE_HOST("BaseGrouped::maximum_active_blocks()");
int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
@ -349,7 +349,7 @@ public:
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace "
CUTLASS_TRACE_HOST("BaseGrouped::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
// Workspace

View File

@ -765,6 +765,52 @@ struct DefaultGemmConfiguration<
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
template <typename ElementC,
typename ElementAccumulator>
struct DefaultGemmConfiguration<arch::OpClassTensorOp, arch::Sm90, double,
double, ElementC, ElementAccumulator> {
static int const kAlignmentA = 1;
static int const kAlignmentB = 1;
using ThreadblockShape = GemmShape<128, 256, 64>;
using WarpShape = GemmShape<64, 64, 64>;
using InstructionShape = GemmShape<16, 8, 4>;
static int const kStages = 3;
using EpilogueOutputOp = epilogue::thread::LinearCombination<
ElementC, 128 / sizeof_bits<ElementC>::value, ElementAccumulator,
ElementAccumulator>;
using Operator = arch::OpMultiplyAdd;
};
template <>
struct DefaultGemmConfiguration<
arch::OpClassTensorOp,
arch::Sm90,
complex<double>,
complex<double>,
complex<double>,
complex<double>
> {
static int const kAlignmentA = 1;
static int const kAlignmentB = 1;
using ThreadblockShape = GemmShape<64, 64, 16>;
using WarpShape = GemmShape<32, 32, 16>;
using InstructionShape = GemmShape<16, 8, 4>;
static int const kStages = 3;
using EpilogueOutputOp = epilogue::thread::LinearCombination<
complex<double>, 1, complex<double>,
complex<double>>;
using Operator = arch::OpMultiplyAddComplex;
};
} // namespace device
} // namespace gemm
} // namespace cutlass

View File

@ -0,0 +1,848 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a Block-Ell sparse gemm kernel.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/ell_gemm.h"
#include "cutlass/gemm/kernel/default_ell_gemm.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
/*! Blocked-Ell sparse gemm device-level operator. This is an interface to efficient CUTLASS
Blocked-Ell kernels that may be invoked from host code.
The contributions of this class are:
1. At compile time, it maps data types and high-level structural parameters onto
specific CUTLASS components.
2. At runtime, it maps logical arguments to Blocked-Ell problems to kernel parameters.
3. At runtime, it launches kernels on the device.
An example of a CUTLASS EllGemm operator is as follows:
//
// Instantiate the CUTLASS EllGemm operator.
//
cutlass::gemm::device::EllGemm<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::ColumnMajor,
cutlass::half_t,
cutlass::layout::ColumnMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
cutlass::half_t, 128 / cutlass::sizeof_bits<cutlass::half_t>::value,
float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
4, // Stages
128 / cutlass::sizeof_bits<cutlass::half_t>::value, // Alignment A
128 / cutlass::sizeof_bits<cutlass::half_t>::value // Alignment B
> ellgemm_op;
//
// Launch the EllGemm operation on the device
//
Description of parameters and tensors used to represent the Blocked-Ellpack (ELL) format:
a_rows - Rows in the sparse matrix.
a_cols - Columns in the sparse matrix.
BlockedEllA - Packed matrix (ellValue matrix) that stores non-zero values in
consecutive blocks, whose size is (a_rows * a_ell_num_columns)
ell_idx - Blocked-ELL Column indices (ellColInd) matrix, whose size is
(a_rows / a_ell_blocksize) * (a_ell_num_columns / a_ell_blocksize)
a_ell_blocksize - Size of the ELL-Blocks.
a_ell_num_columns - Number of columns in the Blocked-Ellpack format (ellValue columns)
B - Input dense matrix whose size is (a_cols * n)
C/D - Output dense matrix whose size is (a_rows * n)
cutlass::Status status = ellgemm_op({
{a_rows, n, a_cols}, // GemmCoord problem_size
{BlockedEllA, lda}, // TensorRef<cutlass::half_t, layout::RowMajor> ref_BlockedEllA
{B, ldb}, // TensorRef<cutlass::half_t, layout::ColumnMajor> ref_B,
{C, ldc}, // TensorRef<float, layout::ColumnMajor> ref_C,
{D, ldd}, // TensorRef<float, layout::ColumnMajor> ref_D,
ell_idx, // Blocked-ELL Column indices or ellColInd matrix (const int*)
a_ell_num_columns, // Columns in the Blocked-Ellpack (ellValue) matrix (int)
a_ell_blocksize, // Size of the ELL-Blocks (int)
a_ell_base, // Base index of ellColInd (int) - Zero or One
{alpha, beta} // EpilogueOutputOp::Params epilogue_op_params
});
A simplified view of the template is listed below.
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Element type for C and D matrix operands
typename ElementC,
/// Layout type for C and D matrix operands
typename LayoutC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for. This is the minimum SM that
/// supports the intended feature. The device kernel can be built
/// targeting any SM larger than this number.
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Access granularity of A matrix in units of elements
int AlignmentA,
/// Access granularity of B matrix in units of elements
int AlignmentB,
/// Supports split-K with serial reduction
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Sparse matrix is A or not
bool IsASparse
>
class EllGemm;
*/
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator_ = ElementC_,
/// Operator class tag
typename OperatorClass_ = arch::OpClassTensorOp,
/// Tag indicating architecture to tune for
typename ArchTag_ = arch::Sm80,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle_ =
typename threadblock::GemmIdentityThreadblockSwizzle<>,
/// Number of stages used in the pipelined mainloop
int Stages =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kStages,
/// Access granularity of A matrix in units of elements
int AlignmentA =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentA,
/// Access granularity of B matrix in units of elements
int AlignmentB =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentB,
/// If true, kernel supports split-K with serial reduction
bool SplitKSerial = false,
/// Operation performed by GEMM
typename Operator_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::Operator,
/// Sparse matrix is A or not
bool IsASparse = true
>
class EllGemm {
public:
using ElementA = ElementA_;
using LayoutA = LayoutA_;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
using ElementB = ElementB_;
using LayoutB = LayoutB_;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
using ElementAccumulator = ElementAccumulator_;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using EpilogueOutputOp = EpilogueOutputOp_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using Operator = Operator_;
static int const kStages = Stages;
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
static int const kAlignmentC = EpilogueOutputOp::kCount;
static bool const kSplitKSerial = SplitKSerial;
static ComplexTransform const kTransformA = ComplexTransform::kNone;
static ComplexTransform const kTransformB = ComplexTransform::kNone;
static bool const kIsASparse = IsASparse;
/// Define the kernel
using GemmKernel = typename kernel::DefaultEllGemm<
ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementC,
LayoutC,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
kStages,
kSplitKSerial,
Operator,
kIsASparse
>::GemmKernel;
/// Argument structure
struct Arguments {
//
// Data members
//
GemmCoord problem_size;
TensorRef<ElementA const, LayoutA> ref_A;
TensorRef<ElementB const, LayoutB> ref_B;
TensorRef<ElementC const, LayoutC> ref_C;
TensorRef<ElementC, LayoutC> ref_D;
const int* ell_idx;
int ell_ncol;
int ell_blocksize;
int ell_base_idx;
typename EpilogueOutputOp::Params epilogue;
int split_k_slices;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments(): problem_size(0, 0, 0), split_k_slices(1) {
}
/// Constructs an Arguments structure
CUTLASS_HOST_DEVICE
Arguments(
GemmCoord problem_size_,
TensorRef<ElementA const, LayoutA> ref_A_,
TensorRef<ElementB const, LayoutB> ref_B_,
TensorRef<ElementC const, LayoutC> ref_C_,
TensorRef<ElementC, LayoutC> ref_D_,
const int* ell_idx_,
int ell_ncol_,
int ell_blocksize_,
int ell_base_idx_,
typename EpilogueOutputOp::Params epilogue_ =
typename EpilogueOutputOp::Params(),
int split_k_slices = 1
):
problem_size(problem_size_),
ref_A(ref_A_),
ref_B(ref_B_),
ref_C(ref_C_),
ref_D(ref_D_),
ell_idx(ell_idx_),
ell_ncol(ell_ncol_),
ell_blocksize(ell_blocksize_),
ell_base_idx(ell_base_idx_),
epilogue(epilogue_),
split_k_slices(split_k_slices) {
}
};
private:
/// Kernel parameters object
typename GemmKernel::Params params_;
public:
/// Constructs the GEMM.
EllGemm() { }
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
if (!kSplitKSerial && args.split_k_slices > 1) {
return Status::kErrorInvalidProblem;
}
Status status = GemmKernel::can_implement(
args.problem_size,
args.ref_A.non_const_ref(),
args.ref_B.non_const_ref(),
args.ref_C.non_const_ref(),
args.ref_D
);
if (status != Status::kSuccess) {
return status;
}
return Status::kSuccess;
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
size_t bytes = 0;
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
{args.ell_blocksize,
ThreadblockShape::kN, ThreadblockShape::kK},
args.split_k_slices);
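// The swizzle above tiles M by ell_blocksize; expand it to the number of
// ThreadblockShape::kM tiles each ELL block row actually spans. Serial split-K then
// needs one int semaphore per output tile for its serialized reduction.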
tiled_shape.m() *= (args.ell_blocksize + ThreadblockShape::kM - 1 ) / ThreadblockShape::kM;
if (kSplitKSerial && args.split_k_slices > 1) {
bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
}
return bytes;
}
Status set(Arguments const &args, cutlass::gemm::GemmCoord const &grid_shape, void *workspace){
// Initialize the Params structure
params_ = typename GemmKernel::Params{
args.problem_size,
grid_shape,
args.ref_A.non_const_ref(),
args.ref_B.non_const_ref(),
args.ref_C.non_const_ref(),
args.ref_D,
args.ell_idx,
args.ell_ncol,
args.ell_blocksize,
args.ell_base_idx,
args.epilogue,
static_cast<int *>(workspace)
};
return Status::kSuccess;
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
{args.ell_blocksize, ThreadblockShape::kN, ThreadblockShape::kK},
args.split_k_slices);
grid_shape.m() *= (args.ell_blocksize + ThreadblockShape::kM - 1 ) / ThreadblockShape::kM;
if (kSplitKSerial) {
if (args.split_k_slices > 1) {
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
size_t bytes = get_workspace_size(args);
cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
}
else {
if (args.split_k_slices > 1) {
return Status::kErrorInvalidProblem;
}
}
return set(args, grid_shape, workspace);
}
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
if (kSplitKSerial && args.split_k_slices > 1) {
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
}
params_.ref_A.reset(args.ref_A.non_const_ref().data());
params_.ref_B.reset(args.ref_B.non_const_ref().data());
params_.ref_C.reset(args.ref_C.non_const_ref().data());
params_.ref_D.reset(args.ref_D.data());
params_.output_op = args.epilogue;
params_.semaphore = static_cast<int *>(workspace);
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
ThreadblockSwizzle threadblock_swizzle;
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
dim3 block(GemmKernel::kThreadCount, 1, 1);
cudaError_t result;
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
if (smem_size >= (48 << 10)) {
result = cudaFuncSetAttribute(Kernel<GemmKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
result = cudaGetLastError();
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
return run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace);
if (status == Status::kSuccess) {
status = run(stream);
}
return status;
}
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for column-major output; exchanges problem size and operands.
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Element type for C and D matrix operands
typename ElementC_,
/// Element type for internal accumulation
typename ElementAccumulator_,
/// Operator class tag
typename OperatorClass_,
/// Tag indicating architecture to tune for
typename ArchTag_,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_,
/// Epilogue output operator
typename EpilogueOutputOp_,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle_,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Access granularity of A matrix in units of elements
int AlignmentA,
/// Access granularity of B matrix in units of elements
int AlignmentB,
/// If true, kernel supports split-K as a serial reduction
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator_,
/// Sparse matrix is A or not
bool IsASparse>
class EllGemm<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
layout::ColumnMajor, // partially specialized on LayoutC
ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_,
WarpShape_, InstructionShape_, EpilogueOutputOp_,
ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB,
SplitKSerial, Operator_, IsASparse> {
public:
using ElementA = ElementA_;
using LayoutA = LayoutA_;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
using ElementB = ElementB_;
using LayoutB = LayoutB_;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
using ElementC = ElementC_;
using LayoutC = layout::ColumnMajor;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
using ElementAccumulator = ElementAccumulator_;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using EpilogueOutputOp = EpilogueOutputOp_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using Operator = Operator_;
static int const kStages = Stages;
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
static ComplexTransform const kTransformA = ComplexTransform::kNone;
static ComplexTransform const kTransformB = ComplexTransform::kNone;
static bool const kSplitKSerial = SplitKSerial;
static bool const kIsASparse = false;
using UnderlyingOperator = EllGemm<
ElementB,
typename layout::LayoutTranspose<LayoutB>::type,
ElementA,
typename layout::LayoutTranspose<LayoutA>::type,
ElementC,
layout::RowMajor,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
kAlignmentB,
kAlignmentA,
SplitKSerial,
Operator,
kIsASparse
>;
using UnderlyingArguments = typename UnderlyingOperator::Arguments;
using GemmKernel = typename UnderlyingOperator::GemmKernel;
static int const kAlignmentC = UnderlyingOperator::kAlignmentC;
/// Argument structure
struct Arguments {
//
// Data members
//
GemmCoord problem_size;
TensorRef<ElementA const, LayoutA> ref_A;
TensorRef<ElementB const, LayoutB> ref_B;
TensorRef<ElementC const, LayoutC> ref_C;
TensorRef<ElementC, LayoutC> ref_D;
const int* ell_idx;
int ell_ncol;
int ell_blocksize;
int ell_base_idx;
typename EpilogueOutputOp::Params epilogue;
int split_k_slices;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments() { }
/// Constructs an Arguments structure
CUTLASS_HOST_DEVICE
Arguments(
GemmCoord problem_size_,
TensorRef<ElementA const, LayoutA> ref_A_,
TensorRef<ElementB const, LayoutB> ref_B_,
TensorRef<ElementC const, LayoutC> ref_C_,
TensorRef<ElementC, LayoutC> ref_D_,
const int* ell_idx_,
int ell_ncol_,
int ell_blocksize_,
int ell_base_idx_,
typename EpilogueOutputOp::Params epilogue_ =
typename EpilogueOutputOp::Params(),
int split_k_slices = 1
):
problem_size(problem_size_),
ref_A(ref_A_),
ref_B(ref_B_),
ref_C(ref_C_),
ref_D(ref_D_),
ell_idx(ell_idx_),
ell_ncol(ell_ncol_),
ell_blocksize(ell_blocksize_),
ell_base_idx(ell_base_idx_),
epilogue(epilogue_),
split_k_slices(split_k_slices) { }
};
private:
UnderlyingOperator underlying_operator_;
public:
/// Constructs the GEMM.
EllGemm() { }
/// Helper to construct a transposed equivalent for the underlying GEMM operator
static UnderlyingArguments to_underlying_arguments(Arguments const &args) {
return UnderlyingArguments(
{args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},
{args.ref_B.data(), args.ref_B.stride(0)},
{args.ref_A.data(), args.ref_A.stride(0)},
{args.ref_C.data(), args.ref_C.stride(0)},
{args.ref_D.data(), args.ref_D.stride(0)},
args.ell_idx,
args.ell_ncol,
args.ell_blocksize,
args.ell_base_idx,
args.epilogue,
args.split_k_slices
);
}
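// Relies on the identity (A B)^T = B^T A^T: the column-major D requested here is
// computed as the row-major transposed problem with the operands swapped and their
// layouts transposed, which is why M and N are exchanged above.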
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
return UnderlyingOperator::can_implement(to_underlying_arguments(args));
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
size_t bytes = 0;
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
{ThreadblockShape::kM, args.ell_blocksize, ThreadblockShape::kK},
args.split_k_slices);
tiled_shape.n() *= (args.ell_blocksize + ThreadblockShape::kN - 1 ) / ThreadblockShape::kN;
if (kSplitKSerial && args.split_k_slices > 1) {
bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
}
return bytes;
}
Status set(Arguments const &args, cutlass::gemm::GemmCoord const &grid_shape, void *workspace){
// Initialize the Params structure
return underlying_operator_.set(to_underlying_arguments(args), grid_shape, workspace);
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
{args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},
{ThreadblockShape::kM, args.ell_blocksize, ThreadblockShape::kK},
args.split_k_slices);
grid_shape.n() *= (args.ell_blocksize + ThreadblockShape::kN - 1 ) / ThreadblockShape::kN;
if (kSplitKSerial) {
if (args.split_k_slices > 1) {
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
size_t bytes = get_workspace_size(args);
cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
}
else {
if (args.split_k_slices > 1) {
return Status::kErrorInvalidProblem;
}
}
// Initialize the Params structure
return set(args, grid_shape, workspace);
}
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
return underlying_operator_.update(to_underlying_arguments(args), workspace);
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
return underlying_operator_.run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
return run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
}
return status;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@ -29,7 +29,7 @@
*
**************************************************************************************************/
/*! \file
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
\brief Template for a pipelined batch GEMM kernel.
*/
#pragma once

View File

@ -58,6 +58,11 @@ namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
/*!
GemmUniversal is a stateful, reusable GEMM handle. Once initialized for a given GEMM computation
(problem geometry and data references), it can be reused across different GEMM problems having the
same geometry. (Once initialized, details regarding problem geometry and references to workspace memory
cannot be updated.)
The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
batched array variants.
*/
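// A sketch of the handle lifecycle described above (the GEMM type alias and argument
// values are illustrative assumptions; template parameters are elided):
//
// using Gemm = cutlass::gemm::device::GemmUniversal<...>; // element types, layouts, etc.
// Gemm gemm_op;
// gemm_op.initialize(args, workspace, stream); // bind geometry, references, workspace once
// gemm_op.run(stream);                         // launch; may be re-run any number of times
// gemm_op.update(new_args);                    // lightweight update: same geometry, new data references
// gemm_op.run(stream);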

View File

@ -109,7 +109,6 @@ public:
using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator;
using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
using UnderlyingOperator = GemmUniversalBase<GemmKernel>;
using Arguments = typename UnderlyingOperator::Arguments;
@ -160,10 +159,11 @@ public:
return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
}
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
/// Lightweight update given a subset of arguments. Problem geometry is assumed to
/// remain the same.
Status update(Arguments const &args) {
return underlying_operator_.update(to_underlying_arguments(args), workspace);
return underlying_operator_.update(to_underlying_arguments(args));
}
/// Runs the kernel using initialized state.

View File

@@ -28,15 +28,15 @@
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
/*!
\file
\brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
batched array variants.
\brief The universal GEMM accommodates streamk, batched strided, and batched array variants.
*/
#pragma once
//#include <limits>
#include <limits>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
@@ -44,7 +44,6 @@
#include "cutlass/device_kernel.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
@@ -52,7 +51,7 @@
#include "cutlass/trace.h"
////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
@@ -67,7 +66,7 @@ public:
using GemmKernel = GemmKernel_;
using ThreadblockShape = typename GemmKernel::Mma::Shape;
using ElementA = typename GemmKernel::ElementA;
using LayoutA = typename GemmKernel::LayoutA;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
@@ -83,7 +82,8 @@ public:
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;
/// Numerical accumulation element type
using ElementAccumulator = typename GemmKernel::Mma::ElementC;
using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
@@ -94,316 +94,285 @@ public:
protected:
/// Kernel parameters object
typename GemmKernel::Params params_;
//
// Device properties (uniform across all instances of the current thread)
//
protected:
// Device ordinal
thread_local static int device_ordinal_;
/// Private helper to obtain the grid dimensions with fix-up for split-K
static void get_grid_shape_(gemm::GemmCoord &grid_tiled_shape, int &gemm_k_size, Arguments const &args) {
/// Device SM count
thread_local static int device_sms_;
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
/// Kernel SM occupancy (in thread blocks)
thread_local static int sm_occupancy_;
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
args.batch_count);
gemm_k_size = args.problem_size.k();
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
/// Initialize static thread-local members for the thread's current device,
/// if necessary.
static Status init_device_props()
{
CUTLASS_TRACE_HOST("GemmUniversalBase::init_device_props()");
int const kAlignK = const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
cudaError_t cudart_result;
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
if (gemm_k_size) {
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
}
}
}
public:
/// Constructs the GEMM.
GemmUniversalBase() { }
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
// Determine grid shape
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
ThreadblockSwizzle threadblock_swizzle;
dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);
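// CUDA limits gridDim.y and gridDim.z to 16-bit extents (65535).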
if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) {
return Status::kErrorInvalidProblem;
}
return GemmKernel::can_implement(args);
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
CUTLASS_TRACE_HOST("GemmUniversalBase::get_workspace_size()");
size_t workspace_bytes = 0;
// Determine grid shape
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
if (args.mode == GemmUniversalMode::kGemmSplitKParallel) {
// Split-K parallel always requires a temporary workspace
workspace_bytes =
sizeof(ElementC) *
size_t(args.batch_stride_D) *
size_t(grid_tiled_shape.k());
}
else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) {
// Serial split-K only requires a temporary workspace if the number of partitions along the
// GEMM K dimension is greater than one.
workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
// Get current device ordinal
int current_ordinal;
cudart_result = cudaGetDevice(&current_ordinal);
if (cudart_result != cudaSuccess) {
CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(cudart_result));
return Status::kErrorInternal;
}
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
// Done if matches the current static member
if (current_ordinal == device_ordinal_) {
// Already initialized
return Status::kSuccess;
}
workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape);
return workspace_bytes;
}
// Update SM count member
cudart_result = cudaDeviceGetAttribute (&device_sms_, cudaDevAttrMultiProcessorCount, current_ordinal);
if (cudart_result != cudaSuccess) {
CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(cudart_result));
return Status::kErrorInternal;
}
/// Computes the grid shape
static dim3 get_grid_shape(Arguments const &args) {
CUTLASS_TRACE_HOST("GemmUniversalBase::get_grid_shape()");
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
CUTLASS_TRACE_HOST(
" grid_tiled_shape: " << grid_tiled_shape << "\n"
<< " result = {" << result << "}");
return result;
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int smem_capacity = -1) {
CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()");
int max_active_blocks = -1;
// Update the kernel function's shared memory configuration for the current device
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
if (smem_size >= (48 << 10))
{
// Requires more than 48KB: configure for extended, dynamic shared memory
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
if (smem_size <= (48 << 10)) {
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks,
Kernel<GemmKernel>,
GemmKernel::kThreadCount,
cudart_result = cudaFuncSetAttribute(
Kernel2<GemmKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (result == cudaSuccess) {
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
}
else {
// Query assuming zero shared memory then compute occupancy limit based on SMEM
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks,
Kernel<GemmKernel>,
GemmKernel::kThreadCount,
0);
if (result != cudaSuccess) {
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
<< cudaGetErrorString(result));
return -1;
if (cudart_result != cudaSuccess) {
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result));
return Status::kErrorInternal;
}
if (smem_capacity < 0) {
int device_idx = 0;
result = cudaGetDevice(&device_idx);
if (result != cudaSuccess) {
return -1;
}
cudaDeviceProp properties;
result = cudaGetDeviceProperties(&properties, device_idx);
if (result != cudaSuccess) {
return -1;
}
smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
}
int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);
CUTLASS_TRACE_HOST(" occupancy: " << occupancy);
return occupancy;
}
CUTLASS_TRACE_HOST(" returning internal error");
return -1;
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
size_t workspace_bytes = get_workspace_size(args);
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
if (workspace_bytes) {
if (!workspace) {
CUTLASS_TRACE_HOST(" error: device workspace must not be null");
return Status::kErrorWorkspaceNull;
}
if (args.mode == GemmUniversalMode::kGemm) {
CUTLASS_TRACE_HOST(" clearing device workspace");
cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);
if (result != cudaSuccess) {
CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
}
// Get CUDA grid shape
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
// Initialize the Params structure
params_ = typename GemmKernel::Params(
args,
grid_tiled_shape,
gemm_k_size,
static_cast<int *>(workspace)
);
// Specify shared memory capacity for kernel.
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
if (smem_size >= (48 << 10)) {
cudaError_t result = cudaFuncSetAttribute(Kernel<GemmKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (result != cudaSuccess) {
cudart_result = cudaFuncSetAttribute(
Kernel2<GemmKernel>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100); // 100% shared memory
if (cudart_result != cudaSuccess) {
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result));
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
CUTLASS_TRACE_HOST("GemmUniversalBase()::update() - workspace: " << workspace);
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes && !workspace) {
return Status::kErrorWorkspaceNull;
// Update SM occupancy member
cudart_result = cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
&sm_occupancy_,
Kernel2<GemmKernel>,
GemmKernel::kThreadCount,
int(sizeof(typename GemmKernel::SharedStorage)),
cudaOccupancyDisableCachingOverride);
if (cudart_result != cudaSuccess) {
CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned error " << cudaGetErrorString(cudart_result));
return Status::kErrorInternal;
}
params_.update(args, workspace);
// Update device ordinal member on success
device_ordinal_ = current_ordinal;
CUTLASS_TRACE_HOST(" "
"device_ordinal: (" << device_ordinal_ << "), "
"device_sms: (" << device_sms_ << "), "
"sm_occupancy: (" << sm_occupancy_ << ")");
return Status::kSuccess;
}
protected:
//
// Instance data members
//
/// Kernel parameters
typename GemmKernel::Params params_;
/// Initialize params member
Status init_params(Arguments const &args)
{
// Initialize static device properties, if necessary
Status result = init_device_props();
if (result != Status::kSuccess) {
return result;
}
// Initialize params member
params_ = typename GemmKernel::Params(args, device_sms_, sm_occupancy_);
return Status::kSuccess;
}
public:
//---------------------------------------------------------------------------------------------
// Stateless API
//---------------------------------------------------------------------------------------------
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args)
{
CUTLASS_TRACE_HOST("GemmUniversalBase::can_implement()");
// Initialize static kernel and device properties, if necessary.
Status result = init_device_props();
if (result != Status::kSuccess) {
return result;
}
dim3 grid = get_grid_shape(args);
if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
grid.z <= std::numeric_limits<uint16_t>::max()))
{
return Status::kErrorInvalidProblem;
}
return GemmKernel::can_implement(args);
}
/// Returns the workspace size (in bytes) needed for the problem
/// geometry expressed by these arguments
static size_t get_workspace_size(Arguments const &args)
{
CUTLASS_TRACE_HOST("GemmUniversalBase::get_workspace_size()");
// Initialize parameters from args
GemmUniversalBase base;
if (base.init_params(args) != Status::kSuccess) {
return 0;
}
// Get size from parameters
size_t workspace_bytes = base.params_.get_workspace_size();
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
return workspace_bytes;
}
/// Returns the grid extents in thread blocks to launch
static dim3 get_grid_shape(Arguments const &args)
{
CUTLASS_TRACE_HOST("GemmUniversalBase::get_grid_shape()");
// Initialize parameters from args
GemmUniversalBase base;
if (base.init_params(args) != Status::kSuccess) {
return dim3(0,0,0);
}
// Get dims from parameters
dim3 grid_dims = base.params_.get_grid_dims();
CUTLASS_TRACE_HOST(
" tiled_shape: " << base.params_.get_tiled_shape() << "\n"
<< " grid_dims: {" << grid_dims << "}");
return grid_dims;
}
/// Returns the maximum number of active thread blocks per multiprocessor
static int maximum_active_blocks()
{
CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()");
// Initialize static device properties, if necessary
if (init_device_props() != Status::kSuccess) {
return -1;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << sm_occupancy_);
return sm_occupancy_;
}
//---------------------------------------------------------------------------------------------
// Stateful API
//---------------------------------------------------------------------------------------------
/// Initializes GEMM state from arguments and workspace memory
Status initialize(
Arguments const &args,
void *workspace,
cudaStream_t stream = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
// Initialize parameters from args
Status result = init_params(args);
if (result != Status::kSuccess) {
return result;
}
// Assign and prepare workspace memory
return params_.init_workspace(workspace, stream);
}
/// Lightweight update given a subset of arguments. Problem geometry is assumed to
/// remain the same.
Status update(Arguments const &args)
{
CUTLASS_TRACE_HOST("GemmUniversalBase()::update()");
params_.update(args);
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
Status run(cudaStream_t stream = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBase::run()");
//
// Configure grid and block dimensions
//
ThreadblockSwizzle threadblock_swizzle;
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
dim3 block(GemmKernel::kThreadCount, 1, 1);
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
dim3 block(GemmKernel::kThreadCount, 1, 1);
dim3 grid = params_.get_grid_dims();
//
// Launch kernel
//
CUTLASS_TRACE_HOST(" "
"grid: (" << grid << "), "
"block: (" << block << "), "
"SMEM: (" << smem_size << ")");
CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block
<< "), SMEM: " << smem_size << " bytes");
Kernel2<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
// Launch
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
//
// Query for errors
//
cudaError_t result = cudaGetLastError();
if (result != cudaSuccess) {
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
Status operator()(cudaStream_t stream = nullptr)
{
return run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
cudaStream_t stream = nullptr)
{
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
}
@@ -412,6 +381,24 @@ public:
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Static initializers
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Device ordinal
template <typename GemmKernel_>
thread_local int GemmUniversalBase<GemmKernel_>::device_ordinal_ = -1;
/// Device SM count
template <typename GemmKernel_>
thread_local int GemmUniversalBase<GemmKernel_>::device_sms_ = -1;
/// Kernel SM occupancy (in thread blocks)
template <typename GemmKernel_>
thread_local int GemmUniversalBase<GemmKernel_>::sm_occupancy_ = -1;
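// Note: these members are thread_local so each host thread caches the properties of its own
// current device; init_device_props() refreshes them whenever cudaGetDevice() reports a new ordinal.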
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device

View File

@@ -28,8 +28,10 @@
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
\brief Template for a GEMM kernel that can broadcast a bias vector in the
epilogue.
*/
#pragma once
@@ -45,7 +47,7 @@
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/gemm/kernel/default_gemm_with_broadcast_v2.h"
#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/device/gemm_universal_base.h"
@@ -97,7 +99,7 @@ template <
/// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp'
typename EpilogueOutputOp_ = cutlass::epilogue::thread::LinearCombinationBiasElementwise<
ElementC_, ElementAccumulator_, ElementAccumulator_,
ElementC_, ElementC_, 16 / sizeof(ElementC_)>,
ElementC_, ElementC_, 128 / cutlass::sizeof_bits<ElementC_>::value>,
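// (128 bits per vectorized access; counting in bits keeps the default correct for sub-byte element types)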
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
/// Number of stages used in the pipelined mainloop
@@ -123,7 +125,7 @@ template <
>
class GemmUniversalWithBroadcast :
public GemmUniversalBase<
typename kernel::DefaultGemmWithBroadcastV2<
typename kernel::DefaultGemmWithBroadcast<
ElementA_,
LayoutA_,
TransformA,
@@ -166,7 +168,7 @@ class GemmUniversalWithBroadcast :
static ComplexTransform const kTransformB = TransformB;
using Base = GemmUniversalBase<
typename kernel::DefaultGemmWithBroadcastV2<
typename kernel::DefaultGemmWithBroadcast<
ElementA_,
LayoutA_,
TransformA,

View File

@@ -29,7 +29,8 @@
*
**************************************************************************************************/
/*! \file
\brief
\brief Template for a GEMM kernel that can reduce one of the input matrices
into a vector along the K dimension.
*/
#pragma once

View File

@@ -0,0 +1,837 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Default kernel-level Blocked-Ell sparse gemm operators.
This operator combines threadblock-scoped ELL MMA
with the appropriate threadblock-scoped epilogue.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/wmma.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm.h"
#include "cutlass/gemm/kernel/gemm_pipelined.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h"
#endif //CUTLASS_ARCH_WMMA_ENABLED
#include "cutlass/gemm/kernel/ell_gemm.h"
#include "cutlass/gemm/threadblock/default_ell_mma.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Sparse matrix is A or not
bool IsASparse>
struct DefaultEllGemm;
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Ampere Architecture
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Sparse matrix is A or not
bool IsASparse
>
struct DefaultEllGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm80, ThreadblockShape, WarpShape, InstructionShape,
EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial,
Operator, IsASparse> {
/// Define the threadblock-scoped matrix multiply-accumulate
using Mma = typename cutlass::gemm::threadblock::DefaultEllMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape, WarpShape, InstructionShape, Stages,
Operator>::ThreadblockMma;
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
/// Define the epilogue
using Epilogue =
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
EpilogueOutputOp::kCount>::Epilogue;
/// Define the kernel-level GEMM operator.
using GemmKernel = kernel::EllGemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial, IsASparse>;
};
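//
// A hedged instantiation sketch for this Sm80 specialization (all concrete
// types and tile shapes below are illustrative assumptions):
//
//   using EllKernel = typename cutlass::gemm::kernel::DefaultEllGemm<
//       cutlass::half_t, cutlass::layout::RowMajor, 8,       // A: element, layout, alignment
//       cutlass::half_t, cutlass::layout::ColumnMajor, 8,    // B: element, layout, alignment
//       float, cutlass::layout::RowMajor, float,             // C/D element, layout, accumulator
//       cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
//       cutlass::gemm::GemmShape<128, 128, 32>,              // threadblock tile
//       cutlass::gemm::GemmShape<64, 64, 32>,                // warp tile
//       cutlass::gemm::GemmShape<16, 8, 16>,                 // instruction tile
//       cutlass::epilogue::thread::LinearCombination<float, 4>,
//       cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
//       3, /*SplitKSerial=*/false, cutlass::arch::OpMultiplyAdd,
//       /*IsASparse=*/true>::GemmKernel;
//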
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Turing Architecture
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Sparse matrix is A or not
bool IsASparse
>
struct DefaultEllGemm<
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementC, layout::RowMajor,
ElementAccumulator,
arch::OpClassTensorOp,
arch::Sm75,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
2,
SplitKSerial,
Operator,
IsASparse
> {
/// Define the threadblock-scoped matrix multiply-accumulate
using Mma = typename cutlass::gemm::threadblock::DefaultEllMma<
ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
arch::Sm75,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator
>::ThreadblockMma;
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
/// Define the epilogue
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape,
typename Mma::Operator,
kPartitionsK,
EpilogueOutputOp,
EpilogueOutputOp::kCount
>::Epilogue;
/// Define the kernel-level GEMM operator.
using GemmKernel = kernel::EllGemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial, IsASparse>;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout
template <
/// Element type for A matrix operand
typename ElementA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Number of Interleaved k
int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Sparse matrix is A or not
bool IsASparse>
struct DefaultEllGemm<
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB, ElementC,
layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape,
InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages,
SplitKSerial, Operator, IsASparse> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
using ElementAccumulator = int32_t;
/// Define the threadblock-scoped matrix multiply-accumulate
using Mma = typename cutlass::gemm::threadblock::DefaultEllMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape, WarpShape, InstructionShape, Stages, Operator,
true>::ThreadblockMma;
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
/// Define the epilogue
using Epilogue = typename cutlass::epilogue::threadblock::
DefaultInterleavedEpilogueTensorOp<
ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
using GemmKernel = kernel::EllGemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial, IsASparse>;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Turing Integer Matrix Multiply Interleaved layout
template <
/// Element type for A matrix operand
typename ElementA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of Interleaved k
int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Sparse matrix is A or not
bool IsASparse>
struct DefaultEllGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
kAlignmentA, ElementB,
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
int32_t, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape,
WarpShape, InstructionShape, EpilogueOutputOp,
ThreadblockSwizzle, 2, SplitKSerial, Operator, IsASparse> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
using ElementAccumulator = int32_t;
/// Define the threadblock-scoped matrix multiply-accumulate
using Mma = typename cutlass::gemm::threadblock::DefaultEllMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC,
arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape,
InstructionShape, 2, Operator, true>::ThreadblockMma;
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
/// Define the epilogue
using Epilogue = typename cutlass::epilogue::threadblock::
DefaultInterleavedEpilogueTensorOp<
ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
using GemmKernel = kernel::EllGemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial, IsASparse>;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Volta architecture
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Sparse matrix is A or not
bool IsASparse
>
struct DefaultEllGemm<
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementC, layout::RowMajor,
ElementAccumulator,
arch::OpClassTensorOp,
arch::Sm70,
ThreadblockShape,
WarpShape,
GemmShape<8, 8, 4>,
EpilogueOutputOp,
ThreadblockSwizzle,
2,
SplitKSerial,
Operator,
IsASparse
> {
/// Define the threadblock-scoped matrix multiply-accumulate
using Mma = typename cutlass::gemm::threadblock::DefaultEllMma<
ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
arch::Sm70,
ThreadblockShape,
WarpShape,
GemmShape<8, 8, 4>,
2,
Operator
>::ThreadblockMma;
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
/// Define the epilogue
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp<
ThreadblockShape,
typename Mma::Operator,
kPartitionsK,
EpilogueOutputOp,
EpilogueOutputOp::kCount
>::Epilogue;
/// Define the kernel-level GEMM operator.
using GemmKernel = kernel::EllGemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial, IsASparse>;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for SIMT
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Sparse matrix is A or not
bool IsASparse
>
struct DefaultEllGemm<
ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementC,
layout::RowMajor,
ElementAccumulator,
arch::OpClassSimt,
ArchTag,
ThreadblockShape,
WarpShape,
GemmShape<1, 1, 1>,
EpilogueOutputOp,
ThreadblockSwizzle,
2,
SplitKSerial,
Operator,
IsASparse> {
/// Define the threadblock-scoped matrix multiply-accumulate
using Mma = typename cutlass::gemm::threadblock::DefaultEllMma<
ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassSimt,
arch::Sm50,
ThreadblockShape,
WarpShape,
GemmShape<1, 1, 1>,
2,
Operator>::ThreadblockMma;
static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount;
static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars");
/// Define the epilogue
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt<
ThreadblockShape,
typename Mma::Operator,
EpilogueOutputOp,
kEpilogueElementsPerAccess
>::Epilogue;
/// Define the kernel-level GEMM operator.
using GemmKernel = kernel::EllGemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial, IsASparse>;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Ampere
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages
int Stages,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Sparse matrix is A or not
bool IsASparse
>
struct DefaultEllGemm<ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementC,
layout::RowMajor,
ElementAccumulator,
arch::OpClassSimt,
arch::Sm80,
ThreadblockShape,
WarpShape,
GemmShape<1, 1, 1>,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
SplitKSerial,
Operator,
IsASparse> {
/// Define the threadblock-scoped matrix multiply-accumulate
using Mma = typename cutlass::gemm::threadblock::DefaultEllMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassSimt, arch::Sm80,
ThreadblockShape, WarpShape, GemmShape<1, 1, 1>, Stages,
Operator>::ThreadblockMma;
static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount;
static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars");
/// Define the epilogue
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt<
ThreadblockShape,
typename Mma::Operator,
EpilogueOutputOp,
kEpilogueElementsPerAccess
>::Epilogue;
/// Define the kernel-level GEMM operator.
using GemmKernel = kernel::EllGemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial,IsASparse>;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for SIMT DP4A
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Layout type for C matrix operand
typename LayoutC,
/// Element type for C and D matrix operands
typename ElementC,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Sparse matrix is A or not
bool IsASparse
>
struct DefaultEllGemm<int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB,
ElementC, LayoutC, ElementAccumulator, arch::OpClassSimt,
ArchTag, ThreadblockShape, WarpShape, GemmShape<1, 1, 4>,
EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial,
Operator, IsASparse> {
using InstructionShape = GemmShape<1, 1, 4>;
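// dp4a computes a four-element int8 dot product per instruction, hence the K = 4 instruction shape.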
using ElementA = int8_t;
using ElementB = int8_t;
using OperatorClass = arch::OpClassSimt;
/// Define the threadblock-scoped matrix multiply-accumulate
using Mma = typename cutlass::gemm::threadblock::DefaultEllMma<ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementAccumulator,
LayoutC,
arch::OpClassSimt,
arch::Sm50,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator
>::ThreadblockMma;
static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount;
static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars");
/// Define the epilogue
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt<
ThreadblockShape,
typename Mma::Operator,
EpilogueOutputOp,
kEpilogueElementsPerAccess
>::Epilogue;
/// Define the kernel-level GEMM operator.
using GemmKernel = kernel::EllGemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial, IsASparse>;
};
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Wmma Gemm Kernel
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Layout type for C and D matrix operands
typename LayoutC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Sparse matrix is A or not
bool IsASparse
>
struct DefaultEllGemm<
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementC, LayoutC,
ElementAccumulator,
arch::OpClassWmmaTensorOp,
ArchTag,
ThreadblockShape, WarpShape, InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
SplitKSerial,
Operator,
IsASparse> {
/// Define the threadblock-scoped matrix multiply-accumulate
using Mma = typename cutlass::gemm::threadblock::DefaultEllMma<
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementAccumulator, LayoutC,
arch::OpClassWmmaTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
Stages,
Operator>::ThreadblockMma;
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
/// Define the epilogue
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp<
ThreadblockShape,
typename Mma::Operator,
kPartitionsK,
EpilogueOutputOp,
EpilogueOutputOp::kCount
>::Epilogue;
/// Define the kernel-level GEMM operator.
using GemmKernel = kernel::EllGemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial, IsASparse>;
};
////////////////////////////////////////////////////////////////////////////////
#endif //CUTLASS_ARCH_WMMA_ENABLED
////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass

View File

@@ -135,6 +135,77 @@ template <
struct DefaultGemm;
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Hopper Architecture
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
/// Gather operand A by using an index array
bool GatherA,
/// Gather operand B by using an index array
bool GatherB,
/// Scatter result D by using an index array
bool ScatterD,
/// Permute result D
typename PermuteDLayout
>
struct DefaultGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm90, ThreadblockShape, WarpShape, InstructionShape,
EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial,
Operator, SharedMemoryClear, GatherA, GatherB, ScatterD, PermuteDLayout> {
/// Define the threadblock-scoped matrix multiply-accumulate
using Mma = typename cutlass::gemm::threadblock::DefaultMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90,
ThreadblockShape, WarpShape, InstructionShape, Stages,
Operator, false, SharedMemoryClear, GatherA, GatherB>::ThreadblockMma;
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
/// Define the epilogue
using Epilogue =
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
EpilogueOutputOp::kCount, ScatterD, PermuteDLayout>::Epilogue;
/// Define the kernel-level GEMM operator.
using GemmKernel = kernel::Gemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
};
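// Note: this Sm90 specialization mirrors the Ampere path that follows; it differs only in the
// architecture tag forwarded to DefaultMma, which in turn selects the mainloop tuned for SM 9.0.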
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Ampere Architecture

View File

@@ -119,6 +119,66 @@ struct DefaultGemmComplex;
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Hopper Architecture
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Element type for C and D matrix operands
typename ElementC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB,
/// Multiply-add operator
// (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex)
typename Operator,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial
>
struct DefaultGemmComplex<
ElementA, LayoutA, ElementB, LayoutB, ElementC,
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm90, ThreadblockShape, WarpShape, InstructionShape,
EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, Operator, SplitKSerial> {
/// Define the threadblock-scoped matrix multiply-accumulate
using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex<
ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, ThreadblockShape,
WarpShape, InstructionShape, Stages, TransformA, TransformB, Operator>::ThreadblockMma;
/// Define the epilogue
using Epilogue =
typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp<
ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp,
EpilogueOutputOp::kCount, Operator>::Epilogue;
/// Define the kernel-level GEMM operator.
using GemmKernel = kernel::Gemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
};
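// As with the real-valued Sm90 kernel above, this specialization simply forwards the Sm90 tag to
// DefaultMultistageMmaComplex; epilogue construction is unchanged from the Ampere path.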
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Ampere Architecture
template <
/// Element type for A matrix operand

View File

@@ -49,6 +49,7 @@
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/kernel/gemm_universal_streamk.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"
@@ -227,12 +228,26 @@ struct DefaultGemmUniversal<
PermuteDLayout
>::GemmKernel;
/// Define the kernel in terms of the default kernel
using GemmKernel = kernel::GemmUniversal<
typename DefaultGemmKernel::Mma,
typename DefaultGemmKernel::Epilogue,
ThreadblockSwizzle
>;
/// Universal kernel without StreamkFeature member type
template <class SwizzleT, class Enable = void>
class SelectBase :
public kernel::GemmUniversal<
typename DefaultGemmKernel::Mma,
typename DefaultGemmKernel::Epilogue,
SwizzleT>
{};
/// Universal kernel with StreamkFeature member type
template <class SwizzleT>
class SelectBase<SwizzleT, typename SwizzleT::StreamkFeature> :
public kernel::GemmUniversalStreamk<
typename DefaultGemmKernel::Mma,
typename DefaultGemmKernel::Epilogue,
SwizzleT>
{};
/// Select kernel by ThreadblockSwizzle's support for StreamkFeature
using GemmKernel = SelectBase<ThreadblockSwizzle>;
};
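//
// Kernel selection works by SFINAE on the swizzle type: if SwizzleT exposes a member type
// StreamkFeature (aliased to void, matching SelectBase's defaulted Enable parameter), the
// partial specialization above is chosen and the stream-k kernel is used. A minimal sketch
// of an opting-in swizzle (hypothetical type, for illustration only):
//
//   struct MyStreamkSwizzle {
//     using StreamkFeature = void;  // presence alone selects kernel::GemmUniversalStreamk
//     // ... remainder of the threadblock swizzle interface ...
//   };
//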
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -336,12 +351,26 @@ struct DefaultGemmUniversal<
false
>::GemmKernel;
/// Define the kernel in terms of the default kernel
using GemmKernel = kernel::GemmUniversal<
typename DefaultGemmKernel::Mma,
typename DefaultGemmKernel::Epilogue,
ThreadblockSwizzle
>;
/// Universal kernel without StreamkFeature member type
template <class SwizzleT, class Enable = void>
class SelectBase :
public kernel::GemmUniversal<
typename DefaultGemmKernel::Mma,
typename DefaultGemmKernel::Epilogue,
SwizzleT>
{};
/// Universal kernel with StreamkFeature member type
template <class SwizzleT>
class SelectBase<SwizzleT, typename SwizzleT::StreamkFeature> :
public kernel::GemmUniversalStreamk<
typename DefaultGemmKernel::Mma,
typename DefaultGemmKernel::Epilogue,
SwizzleT>
{};
/// Select kernel by ThreadblockSwizzle's support for StreamkFeature
using GemmKernel = SelectBase<ThreadblockSwizzle>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -1,242 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
Defines a GEMM with Broadcast based on an existing UniversalGemm kernel.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/kernel/gemm_with_fused_epilogue_v2.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast_v2.h"
#include "cutlass/epilogue/threadblock/epilogue_with_broadcast_v2.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp'
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator,
///
typename Enable = void
>
struct DefaultGemmWithBroadcastV2 {
using GemmBase = typename DefaultGemmUniversal<
ElementA_, LayoutA_, TransformA, kAlignmentA,
ElementB_, LayoutB_, TransformB, kAlignmentB,
ElementC_, LayoutC_, ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
Operator
>::GemmKernel;
// Replace epilogue
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithBroadcastTensorOpV2<
typename GemmBase::Epilogue::Shape,
typename GemmBase::Epilogue::WarpMmaOperator,
GemmBase::Epilogue::kPartitionsK,
ElementC_,
typename EpilogueOutputOp::ElementT,
ElementC_,
EpilogueOutputOp,
GemmBase::Epilogue::kElementsPerAccess
>::Epilogue;
// Compose the GEMM kernel
using GemmKernel = GemmWithFusedEpilogueV2<
typename GemmBase::Mma,
Epilogue,
ThreadblockSwizzle
>;
};
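//
// Usage sketch (illustrative only, not part of this header). The element types,
// tile shapes, stage count, and epilogue functor below are hypothetical
// placeholders chosen to show how the default composes into a kernel:
//
//   using EpilogueOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise<
//     float, float, float, float, float, 4>;
//
//   using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithBroadcastV2<
//     cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
//     cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 8,
//     float, cutlass::layout::RowMajor, float,
//     cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
//     cutlass::gemm::GemmShape<128, 128, 32>,
//     cutlass::gemm::GemmShape<64, 64, 32>,
//     cutlass::gemm::GemmShape<16, 8, 16>,
//     EpilogueOp,
//     cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
//     3, cutlass::arch::OpMultiplyAdd>::GemmKernel;
//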
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization: ArchTag = cutlass::arch::Sm70
///
///
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp'
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator,
///
typename Enable
>
struct DefaultGemmWithBroadcastV2<
ElementA_, LayoutA_, TransformA, kAlignmentA,
ElementB_, LayoutB_, TransformB, kAlignmentB,
ElementC_, LayoutC_,
ElementAccumulator,
OperatorClass,
cutlass::arch::Sm70,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
Operator,
Enable
> {
using GemmBase = typename DefaultGemmUniversal<
ElementA_, LayoutA_, TransformA, kAlignmentA,
ElementB_, LayoutB_, TransformB, kAlignmentB,
ElementC_, LayoutC_, ElementAccumulator,
OperatorClass,
cutlass::arch::Sm70,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
Operator
>::GemmKernel;
// Replace epilogue
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithBroadcastVoltaTensorOpV2<
typename GemmBase::Epilogue::Shape,
typename GemmBase::Epilogue::WarpMmaOperator,
GemmBase::Epilogue::kPartitionsK,
ElementC_,
typename EpilogueOutputOp::ElementT,
ElementC_,
EpilogueOutputOp,
GemmBase::Epilogue::kElementsPerAccess
>::Epilogue;
// Compose the GEMM kernel
using GemmKernel = GemmWithFusedEpilogueV2<
typename GemmBase::Mma,
Epilogue,
ThreadblockSwizzle
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -120,6 +120,84 @@ template <
BlasMode BlasMode_ = BlasMode::kSymmetric>
struct DefaultRank2K;
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Hopper Architecture
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Fill Mode for C (kLower or kUpper)
FillMode FillModeC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultRank2K<
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementC, layout::RowMajor, FillModeC,
ElementAccumulator, arch::OpClassTensorOp, arch::Sm90,
ThreadblockShape, WarpShape, InstructionShape,
EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial,
Operator> {
/// Define the threadblock-scoped matrix multiply-accumulate (A x B^T)
using Mma1 = typename cutlass::gemm::threadblock::DefaultMma<
ElementA, LayoutA,
kAlignmentA,
ElementB, typename layout::LayoutTranspose<LayoutB>::type,
kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90,
ThreadblockShape, WarpShape, InstructionShape, Stages,
Operator>::ThreadblockMma;
/// Define the threadblock-scoped matrix multiply-accumulate (B x A^T)
using Mma2 = typename cutlass::gemm::threadblock::DefaultMma<
ElementB, LayoutB,
kAlignmentB,
ElementA, typename layout::LayoutTranspose<LayoutA>::type,
kAlignmentA,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90,
ThreadblockShape, WarpShape, InstructionShape, Stages,
Operator>::ThreadblockMma;
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
/// Define the epilogue
using Epilogue =
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOpBlas3<
ThreadblockShape, typename Mma1::Operator, kPartitionsK, EpilogueOutputOp,
EpilogueOutputOp::kCount, BlasMode::kSymmetric>::Epilogue;
/// Define the kernel-level Rank2K operator.
using Rank2Kkernel = kernel::Rank2KUniversal<Mma1, Mma2, Epilogue, ThreadblockSwizzle, FillModeC, BlasMode::kSymmetric>;
};
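//
// A sketch of the math realized above, assuming the usual BLAS SYR2K convention:
// Mma1 computes A * B^T and Mma2 computes B * A^T, so the kernel performs
//
//     C <- alpha * (A * B^T) + alpha * (B * A^T) + beta * C
//
// with only the triangle of C selected by FillModeC written by the epilogue.
//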
////////////////////////////////////////////////////////////////////////////////

View File

@@ -163,6 +163,170 @@ template <>
};
}
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Hopper Architecture complex datatype (symmetric)
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Element type for C and D matrix operands
typename ElementC,
/// Fill Mode for C (kLower or kUpper)
FillMode FillModeC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB,
/// Operation performed by GEMM
typename Operator,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial>
struct DefaultRank2KComplex<
ElementA, LayoutA, ElementB, LayoutB, ElementC,
layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm90, ThreadblockShape, WarpShape, InstructionShape,
EpilogueOutputOp, ThreadblockSwizzle, Stages,
TransformA, TransformB, Operator, SplitKSerial, BlasMode::kSymmetric> {
static BlasMode const kBlasMode = BlasMode::kSymmetric;
/// Define the threadblock-scoped matrix multiply-accumulate (A x B^T)
using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex<
ElementA, LayoutA,
ElementB, typename layout::LayoutTranspose<LayoutB>::type,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90,
ThreadblockShape, WarpShape, InstructionShape, Stages,
TransformA, TransformB, Operator>::ThreadblockMma;
/// Define the threadblock-scoped matrix multiply-accumulate (B x A^T)
using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex<
ElementB, LayoutB,
ElementA, typename layout::LayoutTranspose<LayoutA>::type,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90,
ThreadblockShape, WarpShape, InstructionShape, Stages,
TransformA, TransformB, Operator>::ThreadblockMma;
/// Define the epilogue
using Epilogue =
typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3<
ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp,
EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue;
/// Define the kernel-level Rank2K operator.
using Rank2Kkernel = kernel::Rank2KUniversal<Mma1, Mma2, Epilogue, ThreadblockSwizzle, FillModeC, kBlasMode>;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Hopper Architecture complex datatype (hermitian)
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Element type for C and D matrix operands
typename ElementC,
/// Fill Mode for C (kLower or kUpper)
FillMode FillModeC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB,
/// Operation performed by GEMM
typename Operator,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial>
struct DefaultRank2KComplex<
ElementA, LayoutA, ElementB, LayoutB, ElementC,
layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm90, ThreadblockShape, WarpShape, InstructionShape,
EpilogueOutputOp, ThreadblockSwizzle, Stages,
TransformA, TransformB, Operator, SplitKSerial, BlasMode::kHermitian> {
static BlasMode const kBlasMode = BlasMode::kHermitian;
// Complex transform for input A and B matrices (function on input layout)
static ComplexTransform const kTransformA = TransformA;
static ComplexTransform const kTransformB = TransformB;
using TransposedComplexTransform = detail::Rank2KTransposedComplexTransform<
LayoutA, LayoutB,
TransformA, TransformB,
kBlasMode>;
// Complex transform on operandA and operandB (function of blas3 computation)
static ComplexTransform const kTransformOperandA = TransposedComplexTransform::kTransformA;
static ComplexTransform const kTransformOperandB = TransposedComplexTransform::kTransformB;
/// Define the threadblock-scoped matrix multiply-accumulate (A x B^H)
using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex<
ElementA, LayoutA,
ElementB, typename layout::LayoutTranspose<LayoutB>::type,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90,
ThreadblockShape, WarpShape, InstructionShape, Stages,
kTransformOperandA, kTransformOperandB, Operator>::ThreadblockMma;
/// Define the threadblock-scoped matrix multiply-accumulate (B x A^H)
using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex<
ElementB, LayoutB,
ElementA, typename layout::LayoutTranspose<LayoutA>::type,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90,
ThreadblockShape, WarpShape, InstructionShape, Stages,
kTransformOperandA, kTransformOperandB, Operator>::ThreadblockMma;
/// Define the epilogue
using Epilogue =
typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3<
ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp,
EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue;
/// Define the kernel-level Rank2K operator.
using Rank2Kkernel = kernel::Rank2KUniversal<Mma1, Mma2, Epilogue, ThreadblockSwizzle, FillModeC, kBlasMode>;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Ampere Architecture complex datatype (symmetric)

View File

@@ -114,6 +114,68 @@ template <
BlasMode BlasMode_ = BlasMode::kSymmetric>
struct DefaultRankK;
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Hopper Architecture
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for C and D matrix operands
typename ElementC,
/// Fill Mode for C (kLower or kUpper)
FillMode FillModeC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultRankK<
ElementA, LayoutA, kAlignmentA,
ElementC, layout::RowMajor, FillModeC,
ElementAccumulator, arch::OpClassTensorOp, arch::Sm90,
ThreadblockShape, WarpShape, InstructionShape,
EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial,
Operator> {
/// Define the threadblock-scoped matrix multiply-accumulate (A x A^T)
using Mma = typename cutlass::gemm::threadblock::DefaultMma<
ElementA, LayoutA,
kAlignmentA,
ElementA, typename layout::LayoutTranspose<LayoutA>::type,
kAlignmentA,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90,
ThreadblockShape, WarpShape, InstructionShape, Stages,
Operator>::ThreadblockMma;
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
/// Define the epilogue
using Epilogue =
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOpBlas3<
ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
EpilogueOutputOp::kCount, BlasMode::kSymmetric>::Epilogue;
/// Define the kernel-level RankK operator.
using RankKkernel = kernel::RankKUniversal<Mma, Epilogue, ThreadblockSwizzle, FillModeC>;
};
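//
// A sketch of the math realized above, assuming the usual BLAS SYRK convention:
// the single mainloop computes A * A^T, so the kernel performs
//
//     C <- alpha * (A * A^T) + beta * C
//
// writing only the triangle of C selected by FillModeC.
//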
////////////////////////////////////////////////////////////////////////////////

View File

@@ -155,6 +155,140 @@ template <>
}
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Hopper Architecture complex datatype (symmetric)
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Element type for C and D matrix operands
typename ElementC,
/// Fill Mode for C (kLower or kUpper)
FillMode FillModeC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Operation performed by GEMM
typename Operator,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial>
struct DefaultRankKComplex<
ElementA, LayoutA, ElementC,
layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm90, ThreadblockShape, WarpShape, InstructionShape,
EpilogueOutputOp, ThreadblockSwizzle, Stages,
TransformA, Operator, SplitKSerial, BlasMode::kSymmetric> {
static BlasMode const kBlasMode = BlasMode::kSymmetric;
/// Define the threadblock-scoped matrix multiply-accumulate (A x B^T)
using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex<
ElementA, LayoutA,
ElementA, typename layout::LayoutTranspose<LayoutA>::type,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90,
ThreadblockShape, WarpShape, InstructionShape, Stages,
TransformA, TransformA, Operator>::ThreadblockMma;
/// Define the epilogue
using Epilogue =
typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3<
ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp,
EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue;
/// Define the kernel-level RankK operator.
using RankKkernel = kernel::RankKUniversal<Mma, Epilogue, ThreadblockSwizzle, FillModeC>;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Hopper Architecture complex datatype (hermitian)
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Element type for C and D matrix operands
typename ElementC,
/// Fill Mode for C (kLower or kUpper)
FillMode FillModeC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Operation performed by GEMM
typename Operator,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial>
struct DefaultRankKComplex<
ElementA, LayoutA, ElementC,
layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm90, ThreadblockShape, WarpShape, InstructionShape,
EpilogueOutputOp, ThreadblockSwizzle, Stages,
TransformA, Operator, SplitKSerial, BlasMode::kHermitian> {
static BlasMode const kBlasMode = BlasMode::kHermitian;
// Complex transform for input A and B matrices (function on input layout)
static ComplexTransform const kTransformA = TransformA;
using TransposedComplexTransform = detail::RankKTransposedComplexTransform<
LayoutA,
TransformA,
kBlasMode>;
// Complex transform on operandA and operandB (function of blas3 computation)
static ComplexTransform const kTransformOperandA = TransposedComplexTransform::kTransformA;
static ComplexTransform const kTransformOperandB = TransposedComplexTransform::kTransformB;
/// Define the threadblock-scoped matrix multiply-accumulate (A x A^H)
using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex<
ElementA, LayoutA,
ElementA, typename layout::LayoutTranspose<LayoutA>::type,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90,
ThreadblockShape, WarpShape, InstructionShape, Stages,
kTransformOperandA, kTransformOperandB, Operator>::ThreadblockMma;
/// Define the epilogue
using Epilogue =
typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3<
ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp,
EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue;
/// Define the kernel-level RankK operator.
using RankKkernel = kernel::RankKUniversal<Mma, Epilogue, ThreadblockSwizzle, FillModeC>;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Ampere Architecture complex datatype (symmetric)
template <
/// Element type for A matrix operand

View File

@@ -123,6 +123,101 @@ template <
BlasMode BlasMode_ = BlasMode::kSymmetric>
struct DefaultSymm;
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Hopper Architecture
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Side Mode for A (kLeft or kRight)
SideMode kSideModeA,
/// Fill Mode for A (kLower or kUpper)
FillMode kFillModeA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultSymm<
ElementA, LayoutA, kSideModeA, kFillModeA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementC, layout::RowMajor,
ElementAccumulator, arch::OpClassTensorOp, arch::Sm90,
ThreadblockShape, WarpShape, InstructionShape,
EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial,
Operator> {
/// Define the threadblock-scoped triangular matrix multiply-accumulate
/// TRMM - with diagonal: alpha * A * B or alpha * B * A
static const DiagType kDiagTypeMma1 = DiagType::kNonUnit;
using Mma1 = typename cutlass::gemm::threadblock::DefaultTrmm<
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
kSideModeA, kFillModeA, kDiagTypeMma1,
ElementAccumulator, layout::RowMajor,
arch::OpClassTensorOp, arch::Sm90,
ThreadblockShape, WarpShape, InstructionShape,
Stages, Operator>::ThreadblockMma;
/// Define the threadblock-scoped triangular matrix multiply-accumulate
/// TRMM - withOUT diagonal: alpha * AT * B or alpha * B * AT
static const DiagType kDiagTypeMma2 = DiagType::kZero;
using LayoutAMma2 = typename platform::conditional<
(kSideModeA == SideMode::kLeft),
typename layout::LayoutTranspose<LayoutA>::type,
LayoutA
>::type;
using LayoutBMma2 = typename platform::conditional<
(kSideModeA == SideMode::kLeft),
LayoutB,
typename layout::LayoutTranspose<LayoutB>::type
>::type;
using Mma2 = typename cutlass::gemm::threadblock::DefaultTrmm<
ElementA, LayoutAMma2, kAlignmentA,
ElementB, LayoutBMma2, kAlignmentB,
kSideModeA, InvertFillMode<kFillModeA>::mode, kDiagTypeMma2,
ElementAccumulator, layout::RowMajor,
arch::OpClassTensorOp, arch::Sm90,
ThreadblockShape, WarpShape, InstructionShape,
Stages, Operator>::ThreadblockMma;
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
/// Define the epilogue
using Epilogue =
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape, typename Mma1::Operator, kPartitionsK, EpilogueOutputOp,
EpilogueOutputOp::kCount>::Epilogue;
/// Define the kernel-level SYMM/HEMM operator.
using SymmKernel = kernel::SymmUniversal<Mma1, Mma2, Epilogue, ThreadblockSwizzle, kSideModeA, kFillModeA>;
};
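//
// Rationale for the two-TRMM decomposition above (a sketch for
// kSideModeA == kLeft and kFillModeA == kLower, with A symmetric and only its
// lower triangle stored):
//
//     A * B = tril(A) * B + tril(A, -1)^T * B
//
// Mma1 multiplies by the stored triangle including the diagonal (kNonUnit), and
// Mma2 multiplies by the transposed strict triangle (kZero diagonal) under the
// inverted fill mode, so every element of A contributes exactly once.
//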
////////////////////////////////////////////////////////////////////////////////
@@ -221,7 +316,6 @@ struct DefaultSymm<
};
////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass

View File

@@ -117,6 +117,199 @@ struct DefaultSymmComplex;
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Hopper Architecture complex datatype (symmetric)
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Side Mode for A (kLeft or kRight)
SideMode kSideModeA,
/// Fill Mode for A (kLower or kUpper)
FillMode kFillModeA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Element type for C and D matrix operands
typename ElementC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial>
struct DefaultSymmComplex<
ElementA, LayoutA, kSideModeA, kFillModeA, ElementB, LayoutB, ElementC,
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm90, ThreadblockShape, WarpShape, InstructionShape,
EpilogueOutputOp, ThreadblockSwizzle, Stages,
Operator, SplitKSerial, BlasMode::kSymmetric> {
static BlasMode const kBlasMode = BlasMode::kSymmetric;
// Complex transforms don't apply to A or B for SYMM
static ComplexTransform const TransformA = ComplexTransform::kNone;
static ComplexTransform const TransformB = ComplexTransform::kNone;
/// Define the threadblock-scoped triangular matrix multiply-accumulate
/// TRMM - with diagonal: alpha * A * B or alpha * B * A
static const DiagType kDiagTypeMma1 = DiagType::kNonUnit;
using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex<
ElementA, LayoutA,
ElementB, LayoutB,
kSideModeA, kFillModeA, kDiagTypeMma1,
ElementAccumulator, layout::RowMajor,
arch::OpClassTensorOp, arch::Sm90,
ThreadblockShape, WarpShape, InstructionShape,
Stages, TransformA, TransformB, Operator>::ThreadblockMma;
/// Define the threadblock-scoped triangular matrix multiply-accumulate
/// TRMM - withOUT diagonal: alpha * AT * B or alpha * B * AT
static const DiagType kDiagTypeMma2 = DiagType::kZero;
using LayoutAMma2 = typename platform::conditional<
(kSideModeA == SideMode::kLeft),
typename layout::LayoutTranspose<LayoutA>::type,
LayoutA
>::type;
using LayoutBMma2 = typename platform::conditional<
(kSideModeA == SideMode::kLeft),
LayoutB,
typename layout::LayoutTranspose<LayoutB>::type
>::type;
using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex<
ElementA, LayoutAMma2,
ElementB, LayoutBMma2,
kSideModeA, InvertFillMode<kFillModeA>::mode, kDiagTypeMma2,
ElementAccumulator, layout::RowMajor,
arch::OpClassTensorOp, arch::Sm90,
ThreadblockShape, WarpShape, InstructionShape,
Stages, TransformA, TransformB, Operator>::ThreadblockMma;
/// Define the epilogue
using Epilogue =
typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp<
ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp,
EpilogueOutputOp::kCount, Operator>::Epilogue;
/// Define the kernel-level Symm operator.
using SymmKernel = kernel::SymmUniversal<Mma1, Mma2, Epilogue, ThreadblockSwizzle, kSideModeA, kFillModeA>;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Hopper Architecture complex datatype (hermitian)
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Side Mode for A (kLeft or kRight)
SideMode kSideModeA,
/// Fill Mode for A (kLower or kUpper)
FillMode kFillModeA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Element type for C and D matrix operands
typename ElementC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial>
struct DefaultSymmComplex<
ElementA, LayoutA, kSideModeA, kFillModeA, ElementB, LayoutB, ElementC,
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm90, ThreadblockShape, WarpShape, InstructionShape,
EpilogueOutputOp, ThreadblockSwizzle, Stages,
Operator, SplitKSerial, BlasMode::kHermitian> {
static BlasMode const kBlasMode = BlasMode::kHermitian;
/// Define the threadblock-scoped triangular matrix multiply-accumulate
/// TRMM - with diagonal: alpha * A * B or alpha * B * A
static const DiagType kDiagTypeMma1 = DiagType::kNonUnit;
static ComplexTransform const TransformAMma1 = ComplexTransform::kNone;
static ComplexTransform const TransformBMma1 = ComplexTransform::kNone;
using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex<
ElementA, LayoutA,
ElementB, LayoutB,
kSideModeA, kFillModeA, kDiagTypeMma1,
ElementAccumulator, layout::RowMajor,
arch::OpClassTensorOp, arch::Sm90,
ThreadblockShape, WarpShape, InstructionShape,
Stages, TransformAMma1, TransformBMma1, Operator, BlasMode::kHermitian>::ThreadblockMma;
/// Define the threadblock-scoped triangular matrix multiply-accumulate
/// TRMM - withOUT diagonal - with conjugate transpose: alpha * AT * B or alpha * B * AT
static const DiagType kDiagTypeMma2 = DiagType::kZero;
using LayoutAMma2 = typename platform::conditional<
(kSideModeA == SideMode::kLeft),
typename layout::LayoutTranspose<LayoutA>::type,
LayoutA
>::type;
using LayoutBMma2 = typename platform::conditional<
(kSideModeA == SideMode::kLeft),
LayoutB,
typename layout::LayoutTranspose<LayoutB>::type
>::type;
static ComplexTransform const TransformAMma2 = (kSideModeA == SideMode::kLeft) ?
ComplexTransform::kConjugate : ComplexTransform::kNone;
static ComplexTransform const TransformBMma2 = (kSideModeA == SideMode::kLeft) ?
ComplexTransform::kNone : ComplexTransform::kConjugate;
using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex<
ElementA, LayoutAMma2,
ElementB, LayoutBMma2,
kSideModeA, InvertFillMode<kFillModeA>::mode, kDiagTypeMma2,
ElementAccumulator, layout::RowMajor,
arch::OpClassTensorOp, arch::Sm90,
ThreadblockShape, WarpShape, InstructionShape,
Stages, TransformAMma2, TransformBMma2, Operator>::ThreadblockMma;
/// Define the epilogue
using Epilogue =
typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp<
ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp,
EpilogueOutputOp::kCount, Operator>::Epilogue;
/// Define the kernel-level Symm operator.
using SymmKernel = kernel::SymmUniversal<Mma1, Mma2, Epilogue, ThreadblockSwizzle, kSideModeA, kFillModeA>;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Ampere Architecture complex datatype (symmetric)
template <
/// Element type for A matrix operand
@@ -310,7 +503,6 @@ struct DefaultSymmComplex<
////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass

View File

@@ -124,6 +124,76 @@ struct DefaultTrmm;
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Hopper Architecture
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Side Mode for the kernel
SideMode kSideMode,
/// Fill Mode for the triangular matrix
FillMode kFillMode,
/// Diag Type for the triangular matrix
DiagType kDiagType,
/// Element type for C and D matrix operands
typename ElementC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultTrmm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
kSideMode, kFillMode, kDiagType, ElementC,
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm90, ThreadblockShape, WarpShape, InstructionShape,
EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial,
Operator> {
/// Define the threadblock-scoped triangular matrix multiply-accumulate
using Mma = typename cutlass::gemm::threadblock::DefaultTrmm<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
kSideMode, kFillMode, kDiagType,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90,
ThreadblockShape, WarpShape, InstructionShape, Stages,
Operator>::ThreadblockMma;
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
/// Define the epilogue
using Epilogue =
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
EpilogueOutputOp::kCount>::Epilogue;
/// Define the kernel-level TRMM operator.
using TrmmKernel = kernel::TrmmUniversal<Mma, Epilogue, ThreadblockSwizzle, kSideMode, kFillMode, kDiagType>;
};
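//
// Note on kPartitionsK above: it splits a threadblock tile's K extent across
// warps. Illustrative numbers: ThreadblockShape::kK = 64 with WarpShape::kK = 32
// gives kPartitionsK = 2, and the epilogue reduces the two partial accumulators;
// equal K extents give kPartitionsK = 1 and no reduction.
//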
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Ampere Architecture
template <
/// Element type for A matrix operand

View File

@@ -122,6 +122,74 @@ struct DefaultTrmmComplex;
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Hopper Architecture
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Side Mode for the kernel
SideMode kSideMode,
/// Fill Mode for the triangular matrix
FillMode kFillMode,
/// Diag Type for the triangular matrix
DiagType kDiagType,
/// Element type for C and D matrix operands
typename ElementC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB,
/// Multiply-add operator
// (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex)
typename Operator,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial
>
struct DefaultTrmmComplex<
ElementA, LayoutA, ElementB, LayoutB,
kSideMode, kFillMode, kDiagType,
ElementC, layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm90, ThreadblockShape, WarpShape, InstructionShape,
EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, Operator, SplitKSerial> {
/// Define the threadblock-scoped matrix multiply-accumulate
using Mma = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex<
ElementA, LayoutA, ElementB, LayoutB,
kSideMode, kFillMode, kDiagType,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, ThreadblockShape,
WarpShape, InstructionShape, Stages, TransformA, TransformB, Operator>::ThreadblockMma;
/// Define the epilogue
using Epilogue =
typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp<
ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp,
EpilogueOutputOp::kCount, Operator>::Epilogue;
/// Define the kernel-level TRMM operator.
using TrmmKernel = kernel::TrmmUniversal<Mma, Epilogue, ThreadblockSwizzle, kSideMode, kFillMode, kDiagType>;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Ampere Architecture
template <
/// Element type for A matrix operand

View File

@@ -0,0 +1,830 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a Block-Ell sparse gemm kernel.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "cutlass/arch/arch.h"
#include "cutlass/transform/threadblock/ell_iterator.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
bool SplitKSerial, ///! If true, code supporting split-K via serial reduction is enabled.
bool IsASparse ///! If true, A is sparse matrix
>
struct EllGemm {
using Mma = Mma_;
using Epilogue = Epilogue_;
using OutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static bool const kSplitKSerial = SplitKSerial;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
/// Parameters structure
struct Params {
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorA::TensorRef ref_A;
typename Mma::IteratorB::Params params_B;
typename Mma::IteratorB::TensorRef ref_B;
typename Epilogue::OutputTileIterator::Params params_C;
typename Epilogue::OutputTileIterator::TensorRef ref_C;
typename Epilogue::OutputTileIterator::Params params_D;
typename Epilogue::OutputTileIterator::TensorRef ref_D;
typename OutputOp::Params output_op;
int *semaphore;
int gemm_k_iterations;
int gemm_k_size;
const int* ell_idx;
int ell_ncol;
int ell_blocksize;
int ell_base_idx;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { }
CUTLASS_HOST_DEVICE
Params(
cutlass::gemm::GemmCoord const & problem_size,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
typename Mma::IteratorA::TensorRef ref_A,
typename Mma::IteratorB::TensorRef ref_B,
typename Epilogue::OutputTileIterator::TensorRef ref_C,
typename Epilogue::OutputTileIterator::TensorRef ref_D,
const int* ell_idx,
int ell_ncol,
int ell_blocksize,
int ell_base_idx,
typename OutputOp::Params output_op = typename OutputOp::Params(),
int *workspace = nullptr
):
problem_size(problem_size),
grid_tiled_shape(grid_tiled_shape),
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
params_A(ref_A.layout()),
ref_A(ref_A),
params_B(ref_B.layout()),
ref_B(ref_B),
params_C(ref_C.layout()),
ref_C(ref_C),
params_D(ref_D.layout()),
ref_D(ref_D),
output_op(output_op),
ell_idx(ell_idx),
ell_ncol(ell_ncol),
ell_blocksize(ell_blocksize),
ell_base_idx(ell_base_idx)
{
int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Assign the member (rather than shadowing it with a local) so Params is fully initialized.
gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
semaphore = workspace;
}
};
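//
// Worked example for the constructor above (illustrative numbers): with
// problem_size.k() = 1024 and Mma::Shape::kK = 32, total_gemm_k_iterations = 32.
// Splitting across grid_tiled_shape.k() = 4 slices yields 8 mainloop iterations
// per slice, i.e. gemm_k_size = 8 * 32 = 256 elements of K per threadblock.
//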
/// Shared memory storage structure
struct SharedStorage {
union{
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
typename cutlass::transform::threadblock::ell::SharedStorage ell;
};
//
// Methods
//
CUTLASS_HOST_DEVICE
EllGemm() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(
cutlass::gemm::GemmCoord const & problem_size,
typename Mma::IteratorA::TensorRef ref_A,
typename Mma::IteratorB::TensorRef ref_B,
typename Epilogue::OutputTileIterator::TensorRef ref_C,
typename Epilogue::OutputTileIterator::TensorRef ref_D) {
static int const kAlignmentA = (platform::is_same<typename Mma::IteratorA::Layout,
layout::ColumnMajorInterleaved<32>>::value)
? 32
: (platform::is_same<typename Mma::IteratorA::Layout,
layout::ColumnMajorInterleaved<64>>::value)
? 64
: Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = (platform::is_same<typename Mma::IteratorB::Layout,
layout::RowMajorInterleaved<32>>::value)
? 32
: (platform::is_same<typename Mma::IteratorB::Layout,
layout::RowMajorInterleaved<64>>::value)
? 64
: Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
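// Interleaved layouts above must be aligned to a whole interleave (32 or 64
// elements); all other layouts fall back to the iterator's vector access width.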
if (!TensorRef_aligned(ref_A, kAlignmentA)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_B, kAlignmentB)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_C, kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_D, kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) ||
(problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||
(problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
return Status::kSuccess;
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
return;
}
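// Map this threadblock tile onto the Block-ELL structure: each ELL block of
// ell_blocksize rows spans ceil(ell_blocksize / Mma::Shape::kM) tiles along M.
// Illustratively, ell_blocksize = 128 with Mma::Shape::kM = 64 gives two tiles
// per ELL block.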
int tile_in_ell_block = (params.ell_blocksize + Mma::Shape::kM - 1) / Mma::Shape::kM;
int ell_block_offset_m = threadblock_tile_offset.m() / tile_in_ell_block;
int tile_offset_m = threadblock_tile_offset.m() % tile_in_ell_block;
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
typename Mma::FragmentC accumulators;
accumulators.clear();
// skip computation if matrix is 0
if (params.ell_ncol > 0) {
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
ell_block_offset_m * params.ell_blocksize
+ tile_offset_m * Mma::Shape::kM,
threadblock_tile_offset.k() * params.gemm_k_size
};
cutlass::MatrixCoord tb_offset_B{
threadblock_tile_offset.k() * params.gemm_k_size,
threadblock_tile_offset.n() * Mma::Shape::kN
};
int ell_idx_start =
(threadblock_tile_offset.m() / tile_in_ell_block) *
(params.ell_ncol / params.ell_blocksize);
const int* ell_idx_ptr = &(params.ell_idx[ell_idx_start]);
// Problem size is a function of threadblock index in the K dimension
int problem_size_k = min(
params.problem_size.k(),
(threadblock_tile_offset.k() + 1) * params.gemm_k_size);
problem_size_k = min(problem_size_k, params.ell_ncol);
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations =
(problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
params.params_A,
params.ref_A.data(),
{params.problem_size.m(), problem_size_k},
thread_idx,
tb_offset_A);
typename Mma::IteratorB iterator_B(
params.params_B,
params.ref_B.data(),
{problem_size_k, params.problem_size.n()},
thread_idx,
tb_offset_B);
// Define coef for ELL index depending on LayoutB
int ell_stride = iterator_B.get_stride();
typename cutlass::transform::threadblock::ell::Iterator ell_iterator(
shared_storage.ell,
ell_idx_ptr,
params.ell_blocksize,
params.ell_base_idx,
Mma::Shape::kK,
problem_size_k,
ell_stride,
thread_idx
);
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
if (!kSplitKSerial || gemm_k_iterations > 0) {
// check if index computations can be skipped
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
constexpr bool is_double = (sizeof(typename Mma::IteratorA::Element) == 8);
constexpr bool is_multiple_alignment =
(kAlignmentA > 1) && (kAlignmentB > 1) && (kAlignmentC > 1);
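// The specialized mainloop below can skip per-tile index computation; it requires
// ell_blocksize to be a power of two (tested via x & (x - 1) == 0) and at least
// one K-tile wide.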
const bool is_specialized_blocksize =
((params.ell_blocksize) & (params.ell_blocksize-1)) == 0
&& params.ell_blocksize >= Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
if ((is_double || is_multiple_alignment) && is_specialized_blocksize) {
mma.operator()<true, true>(
gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator);
}
else {
mma.operator()<true, false>(
gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator);
}
}
} // if (params.ell_ncol > 0)
//
// Epilogue
//
OutputOp output_op(params.output_op);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
ell_block_offset_m = threadblock_tile_offset.m() / tile_in_ell_block;
tile_offset_m = threadblock_tile_offset.m() % tile_in_ell_block;
// assume identity swizzle
MatrixCoord threadblock_offset(
ell_block_offset_m * params.ell_blocksize
+ tile_offset_m * Mma::Shape::kM,
threadblock_tile_offset.n() * Mma::Shape::kN
);
// avoid out-of-bounds accesses
MatrixCoord threadblock_extent(
min(params.problem_size.m(),
ell_block_offset_m * params.ell_blocksize
+ min((tile_offset_m + 1) * Mma::Shape::kM, params.ell_blocksize)),
min(params.problem_size.n(),
(threadblock_tile_offset.n()+1) * Mma::Shape::kN)
);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
// If performing a reduction via split-K, fetch the initial synchronization
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
// Indicate which position in a serial reduction the output operator is currently updating
output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
}
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params.params_C,
params.ref_C.data(),
threadblock_extent,
thread_idx,
threadblock_offset
);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params.params_D,
params.ref_D.data(),
threadblock_extent,
thread_idx,
threadblock_offset
);
Epilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator construction
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
if (threadblock_tile_offset.k()) {
iterator_C = iterator_D;
}
semaphore.wait(threadblock_tile_offset.k());
}
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op, iterator_D, accumulators, iterator_C);
//
// Release the semaphore
//
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
// The final threadblock resets the semaphore for subsequent grids.
lock = 0;
}
else {
// Otherwise, the semaphore is incremented
lock = threadblock_tile_offset.k() + 1;
}
semaphore.release(lock);
}
}
};
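//
// The primary template above treats A as the Block-ELL (sparse) operand. The
// partial specialization below mirrors it with B sparse: ELL blocks tile along N
// instead of M, and the ELL stride is taken from iterator_A instead of iterator_B.
//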
// B is Sparse
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
>
struct EllGemm<Mma_, Epilogue_, ThreadblockSwizzle_, SplitKSerial, false> {
using Mma = Mma_;
using Epilogue = Epilogue_;
using OutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static bool const kSplitKSerial = SplitKSerial;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
/// Parameters structure
struct Params {
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorA::TensorRef ref_A;
typename Mma::IteratorB::Params params_B;
typename Mma::IteratorB::TensorRef ref_B;
typename Epilogue::OutputTileIterator::Params params_C;
typename Epilogue::OutputTileIterator::TensorRef ref_C;
typename Epilogue::OutputTileIterator::Params params_D;
typename Epilogue::OutputTileIterator::TensorRef ref_D;
typename OutputOp::Params output_op;
int *semaphore;
int gemm_k_iterations;
int gemm_k_size;
const int* ell_idx;
int ell_ncol;
int ell_blocksize;
int ell_base_idx;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { }
CUTLASS_HOST_DEVICE
Params(
cutlass::gemm::GemmCoord const & problem_size,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
typename Mma::IteratorA::TensorRef ref_A,
typename Mma::IteratorB::TensorRef ref_B,
typename Epilogue::OutputTileIterator::TensorRef ref_C,
typename Epilogue::OutputTileIterator::TensorRef ref_D,
const int* ell_idx,
int ell_ncol,
int ell_blocksize,
int ell_base_idx,
typename OutputOp::Params output_op = typename OutputOp::Params(),
int *workspace = nullptr
):
problem_size(problem_size),
grid_tiled_shape(grid_tiled_shape),
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
params_A(ref_A.layout()),
ref_A(ref_A),
params_B(ref_B.layout()),
ref_B(ref_B),
params_C(ref_C.layout()),
ref_C(ref_C),
params_D(ref_D.layout()),
ref_D(ref_D),
output_op(output_op),
ell_idx(ell_idx),
ell_ncol(ell_ncol),
ell_blocksize(ell_blocksize),
ell_base_idx(ell_base_idx)
{
int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Assign the member (rather than shadowing it with a local) so Params is fully initialized.
gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
semaphore = workspace;
}
};
/// Shared memory storage structure
struct SharedStorage {
union{
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
typename cutlass::transform::threadblock::ell::SharedStorage ell;
};
//
// Methods
//
CUTLASS_HOST_DEVICE
EllGemm() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(
cutlass::gemm::GemmCoord const & problem_size,
typename Mma::IteratorA::TensorRef ref_A,
typename Mma::IteratorB::TensorRef ref_B,
typename Epilogue::OutputTileIterator::TensorRef ref_C,
typename Epilogue::OutputTileIterator::TensorRef ref_D) {
static int const kAlignmentA = (platform::is_same<typename Mma::IteratorA::Layout,
layout::ColumnMajorInterleaved<32>>::value)
? 32
: (platform::is_same<typename Mma::IteratorA::Layout,
layout::ColumnMajorInterleaved<64>>::value)
? 64
: Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = (platform::is_same<typename Mma::IteratorB::Layout,
layout::RowMajorInterleaved<32>>::value)
? 32
: (platform::is_same<typename Mma::IteratorB::Layout,
layout::RowMajorInterleaved<64>>::value)
? 64
: Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
if (!TensorRef_aligned(ref_A, kAlignmentA)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_B, kAlignmentB)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_C, kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_D, kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) ||
(problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||
(problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
return Status::kSuccess;
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
return;
}
int tile_in_ell_block = (params.ell_blocksize + Mma::Shape::kN - 1) / Mma::Shape::kN;
int ell_block_offset_n = threadblock_tile_offset.n() / tile_in_ell_block;
int tile_offset_n = threadblock_tile_offset.n() % tile_in_ell_block;
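    // Example (hypothetical values): with ell_blocksize == 512 and
    // Mma::Shape::kN == 128, tile_in_ell_block == 4; a threadblock at
    // tile offset n == 6 then lands in ELL block column 1 (6 / 4) as the
    // third tile within that block (6 % 4 == 2).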
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
typename Mma::FragmentC accumulators;
accumulators.clear();
    // Skip the mainloop entirely when the ELL-blocked matrix has no columns
if (params.ell_ncol > 0) {
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.k() * params.gemm_k_size,
};
cutlass::MatrixCoord tb_offset_B{
threadblock_tile_offset.k() * params.gemm_k_size,
ell_block_offset_n * params.ell_blocksize
+ tile_offset_n * Mma::Shape::kN,
};
int ell_idx_start =
(threadblock_tile_offset.n() / tile_in_ell_block) *
(params.ell_ncol / params.ell_blocksize);
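      // Example (hypothetical values): with ell_ncol == 2048 and
      // ell_blocksize == 512, each ELL block-row stores 2048 / 512 == 4
      // block-column indices, so this threadblock's row of indices begins
      // at params.ell_idx[ell_block_offset_n * 4].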
const int* ell_idx_ptr = &(params.ell_idx[ell_idx_start]);
// Problem size is a function of threadblock index in the K dimension
int problem_size_k = min(
params.problem_size.k(),
(threadblock_tile_offset.k() + 1) * params.gemm_k_size);
problem_size_k = min(problem_size_k, params.ell_ncol);
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations =
(problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
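      // Example (hypothetical values): with problem_size_k == 352,
      // tb_offset_A.column() == 0, and Mma::Shape::kK == 32, the
      // threadblock runs (352 + 31) / 32 == 11 mainloop iterations.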
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
params.params_A,
params.ref_A.data(),
{params.problem_size.m(), problem_size_k},
thread_idx,
tb_offset_A);
typename Mma::IteratorB iterator_B(
params.params_B,
params.ref_B.data(),
{problem_size_k, params.problem_size.n()},
thread_idx,
tb_offset_B);
      // Determine the stride coefficient for the ELL index, which depends on LayoutA
int ell_stride = iterator_A.get_stride();
typename cutlass::transform::threadblock::ell::Iterator ell_iterator(
shared_storage.ell,
ell_idx_ptr,
params.ell_blocksize,
params.ell_base_idx,
Mma::Shape::kK,
problem_size_k,
ell_stride,
thread_idx
);
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
if (!kSplitKSerial || gemm_k_iterations > 0) {
// check if index computations can be skipped
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
        constexpr bool is_double = (sizeof(typename Mma::IteratorA::Element) == 8);
constexpr bool is_multiple_alignment =
(kAlignmentA > 1) && (kAlignmentB > 1) && (kAlignmentC > 1);
const bool is_specialized_blocksize =
((params.ell_blocksize) & (params.ell_blocksize-1)) == 0
&& params.ell_blocksize >= Mma::Shape::kK;
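        // Example (hypothetical values): ell_blocksize == 512 is a power of
        // two (512 & 511 == 0) and, assuming Mma::Shape::kK == 32, satisfies
        // 512 >= 32, so the specialized mainloop below is selected whenever
        // the operand alignments also permit it.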
// Compute threadblock-scoped matrix multiply-add
if ((is_double || is_multiple_alignment) && is_specialized_blocksize) {
mma.operator()<false, true>(
gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator);
}
else {
mma.operator()<false, false>(
gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator);
}
}
    } // if (params.ell_ncol > 0)
//
// Epilogue
//
OutputOp output_op(params.output_op);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
ell_block_offset_n = threadblock_tile_offset.n() / tile_in_ell_block;
tile_offset_n = threadblock_tile_offset.n() % tile_in_ell_block;
    // Assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM,
ell_block_offset_n * params.ell_blocksize
+ tile_offset_n * Mma::Shape::kN
);
    // Avoid out-of-bounds accesses
MatrixCoord threadblock_extent(
min(params.problem_size.m(),
(threadblock_tile_offset.m()+1) * Mma::Shape::kM),
min(params.problem_size.n(),
ell_block_offset_n * params.ell_blocksize
+ min((tile_offset_n + 1) * Mma::Shape::kN, params.ell_blocksize))
);
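    // Example (hypothetical values): with problem_size.n() == 1000,
    // ell_blocksize == 512, Mma::Shape::kN == 128, ell_block_offset_n == 1,
    // and tile_offset_n == 3, the N extent clamps to
    // min(1000, 512 + min(512, 512)) == 1000.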
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
// If performing a reduction via split-K, fetch the initial synchronization
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
// Indicate which position in a serial reduction the output operator is currently updating
output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
}
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params.params_C,
params.ref_C.data(),
threadblock_extent,
thread_idx,
threadblock_offset
);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params.params_D,
params.ref_D.data(),
threadblock_extent,
thread_idx,
threadblock_offset
);
Epilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator construction
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
if (threadblock_tile_offset.k()) {
iterator_C = iterator_D;
}
semaphore.wait(threadblock_tile_offset.k());
}
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op, iterator_D, accumulators, iterator_C);
//
// Release the semaphore
//
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
// The final threadblock resets the semaphore for subsequent grids.
lock = 0;
}
else {
// Otherwise, the semaphore is incremented
lock = threadblock_tile_offset.k() + 1;
}
semaphore.release(lock);
}
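    // Example (hypothetical values): with grid_tiled_shape.k() == 3,
    // K-slice 0 releases lock value 1, slice 1 releases 2, and the final
    // slice resets the semaphore to 0 for subsequent launches.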
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass

View File

@ -315,13 +315,6 @@ public:
static Status can_implement(Arguments const &args) {
return Status::kSuccess;
}
static size_t get_extra_workspace_size(
Arguments const &args,
cutlass::gemm::GemmCoord const &grid_tiled_shape) {
return 0;
}
/// Executes one GEMM
CUTLASS_DEVICE

View File

@ -50,11 +50,22 @@ namespace kernel {
namespace detail {
// Helper for correctly representing problem sizes in grouped kernels
template <bool Transposed>
template <
typename ThreadblockShape,
bool Transposed
>
struct GemmGroupedProblemSizeHelper {
static bool const kTransposed = Transposed;
CUTLASS_HOST_DEVICE
static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) {
return cutlass::gemm::GemmCoord(
((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM),
((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN),
1);
}
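  // Example (hypothetical values): for problem.m() == 1000 and
  // ThreadblockShape::kM == 128, (1000 - 1 + 128) / 128 == 8 tiles,
  // i.e. the usual ceiling division.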
CUTLASS_HOST_DEVICE
static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) {
if (kTransposed) {
@ -77,7 +88,7 @@ template <typename ThreadblockShape,
int ThreadCount,
bool Transposed = false>
struct GemmGroupedProblemVisitor : public GroupedProblemVisitor<
detail::GemmGroupedProblemSizeHelper<Transposed>,
detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>,
ThreadblockShape,
GroupScheduleMode_,
PrefetchTileCount,
@ -85,7 +96,7 @@ struct GemmGroupedProblemVisitor : public GroupedProblemVisitor<
static bool const kTransposed = Transposed;
using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<Transposed>;
using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>;
using Base = GroupedProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
using Params = typename Base::Params;
using SharedStorage = typename Base::SharedStorage;

View File

@ -335,13 +335,6 @@ public:
return Status::kSuccess;
}
static size_t get_extra_workspace_size(
Arguments const &args,
cutlass::gemm::GemmCoord const &grid_tiled_shape) {
return 0;
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {

View File

@ -41,6 +41,7 @@
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/gemm/kernel/params_universal_base.h"
#include "cutlass/layout/matrix.h"
@ -104,16 +105,12 @@ public:
//
/// Argument structure
struct Arguments {
struct Arguments : UniversalArgumentsBase
{
//
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size;
int batch_count;
typename EpilogueOutputOp::Params epilogue;
void const * ptr_A;
@ -132,7 +129,6 @@ public:
int64_t batch_stride_gamma;
int64_t batch_stride_beta;
int64_t batch_stride_C;
int64_t batch_stride_D;
typename LayoutA::Stride stride_a;
typename LayoutB::Stride stride_b;
@ -161,14 +157,13 @@ public:
//
Arguments():
mode(GemmUniversalMode::kGemm),
batch_count(1),
ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr),
ptr_var(nullptr), ptr_mean(nullptr),
ptr_gamma(nullptr), ptr_beta(nullptr),
ptr_gather_A_indices(nullptr),
ptr_gather_B_indices(nullptr),
ptr_scatter_D_indices(nullptr) {}
ptr_scatter_D_indices(nullptr)
{}
/// constructs an arguments structure
Arguments(
@ -202,31 +197,27 @@ public:
typename LayoutC::Stride stride_d,
int const *ptr_gather_A_indices = nullptr,
int const *ptr_gather_B_indices = nullptr,
int const *ptr_scatter_D_indices = nullptr
):
mode(mode),
problem_size(problem_size),
batch_count(batch_count),
int const *ptr_scatter_D_indices = nullptr)
:
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
epilogue(epilogue),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
ptr_var(ptr_var), ptr_mean(ptr_mean),
ptr_gamma(ptr_gamma), ptr_beta(ptr_beta),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),
batch_stride_var(batch_stride_var), batch_stride_mean(batch_stride_mean),
batch_stride_gamma(batch_stride_gamma), batch_stride_beta(batch_stride_beta),
lda(0), ldb(0), ldc(0), ldd(0),
ld_var(0), ld_mean(0),
ld_gamma(0), ld_beta(0),
stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d),
stride_var(stride_var), stride_mean(stride_mean),
stride_gamma(stride_gamma), stride_beta(stride_beta),
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
ptr_scatter_D_indices(ptr_scatter_D_indices) {
lda = 0;
ldb = 0;
ldc = 0;
ldd = 0;
ld_var = 0;
ld_mean = 0;
ld_gamma = 0;
ld_beta = 0;
ptr_scatter_D_indices(ptr_scatter_D_indices)
{
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
}
}
/// constructs an arguments structure
Arguments(
@ -260,23 +251,22 @@ public:
typename LayoutC::Stride::LongIndex ldd,
int const *ptr_gather_A_indices = nullptr,
int const *ptr_gather_B_indices = nullptr,
int const *ptr_scatter_D_indices = nullptr
):
mode(mode),
problem_size(problem_size),
batch_count(batch_count),
int const *ptr_scatter_D_indices = nullptr)
:
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
epilogue(epilogue),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
ptr_var(ptr_var), ptr_mean(ptr_mean),
ptr_gamma(ptr_gamma), ptr_beta(ptr_beta),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),
batch_stride_var(batch_stride_var), batch_stride_mean(batch_stride_mean),
batch_stride_gamma(batch_stride_gamma), batch_stride_beta(batch_stride_beta),
lda(lda), ldb(ldb), ldc(ldc), ldd(ldd),
ld_var(ld_var), ld_mean(ld_mean),
ld_gamma(ld_gamma), ld_beta(ld_beta),
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
ptr_scatter_D_indices(ptr_scatter_D_indices) {
ptr_scatter_D_indices(ptr_scatter_D_indices)
{
stride_a = make_Coord(lda);
stride_b = make_Coord(ldb);
stride_c = make_Coord(ldc);
@ -286,7 +276,7 @@ public:
stride_gamma = make_Coord(ld_gamma);
stride_beta = make_Coord(ld_beta);
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
}
}
/// Returns arguments for the transposed problem
Arguments transposed_problem() const {
@ -303,17 +293,30 @@ public:
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params {
struct Params : UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC>
{
using ParamsBase = UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC>;
//
// Data members
//
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
typename Epilogue::OutputTileIterator::Params params_C;
@ -321,10 +324,6 @@ public:
typename EpilogueOutputOp::Params output_op;
GemmUniversalMode mode;
int batch_count;
int gemm_k_size;
void * ptr_A;
void * ptr_B;
void * ptr_var;
@ -341,65 +340,30 @@ public:
int64_t batch_stride_gamma;
int64_t batch_stride_beta;
int64_t batch_stride_C;
int64_t batch_stride_D;
int * ptr_gather_A_indices;
int * ptr_gather_B_indices;
int * ptr_scatter_D_indices;
int *semaphore;
//
// Methods
// Host dispatch API
//
CUTLASS_HOST_DEVICE
Params():
swizzle_log_tile(0),
params_A(0),
params_B(0),
params_C(0),
params_D(0),
batch_count(0),
gemm_k_size(0),
mode(cutlass::gemm::GemmUniversalMode::kGemm),
ptr_A(nullptr),
ptr_B(nullptr),
ptr_var(nullptr),
ptr_mean(nullptr),
ptr_gamma(nullptr),
ptr_beta(nullptr),
ptr_C(nullptr),
ptr_D(nullptr),
batch_stride_A(0),
batch_stride_B(0),
batch_stride_var(0),
batch_stride_mean(0),
batch_stride_C(0),
batch_stride_D(0),
ptr_gather_A_indices(nullptr),
ptr_gather_B_indices(nullptr),
ptr_scatter_D_indices(nullptr),
semaphore(nullptr) { }
/// Default constructor
Params() = default;
CUTLASS_HOST_DEVICE
/// Constructor
Params(
Arguments const &args,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
int gemm_k_size,
void *workspace = nullptr
):
problem_size(args.problem_size),
grid_tiled_shape(grid_tiled_shape),
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
Arguments const &args, /// GEMM application arguments
int device_sms, /// Number of SMs on the device
int sm_occupancy) /// Kernel SM occupancy (in thread blocks)
:
ParamsBase(args, device_sms, sm_occupancy),
params_A(args.lda ? make_Coord_with_padding<LayoutA::kStrideRank>(args.lda) : args.stride_a),
params_B(args.ldb ? make_Coord_with_padding<LayoutB::kStrideRank>(args.ldb) : args.stride_b),
params_C(args.ldc ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldc) : args.stride_c),
params_D(args.ldd ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldd) : args.stride_d),
output_op(args.epilogue),
mode(args.mode),
batch_count(args.batch_count),
gemm_k_size(gemm_k_size),
ptr_A(const_cast<void *>(args.ptr_A)),
ptr_B(const_cast<void *>(args.ptr_B)),
ptr_var(const_cast<void *>(args.ptr_var)),
@ -415,19 +379,15 @@ public:
batch_stride_gamma(args.batch_stride_gamma),
batch_stride_beta(args.batch_stride_beta),
batch_stride_C(args.batch_stride_C),
batch_stride_D(args.batch_stride_D),
ptr_gather_A_indices(const_cast<int *>(args.ptr_gather_A_indices)),
ptr_gather_B_indices(const_cast<int *>(args.ptr_gather_B_indices)),
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices)),
semaphore(static_cast<int *>(workspace)) {
}
CUTLASS_HOST_DEVICE
void update(
Arguments const &args,
void *workspace = nullptr) {
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices))
{}
/// Lightweight update given a subset of arguments. Problem geometry is assumed
/// to remain the same.
void update(Arguments const &args)
{
ptr_A = const_cast<void *>(args.ptr_A);
ptr_B = const_cast<void *>(args.ptr_B);
ptr_var = const_cast<void *>(args.ptr_var);
@ -441,22 +401,13 @@ public:
ptr_gather_B_indices = const_cast<int *>(args.ptr_gather_B_indices);
ptr_scatter_D_indices = const_cast<int *>(args.ptr_scatter_D_indices);
batch_stride_A = args.batch_stride_A;
batch_stride_B = args.batch_stride_B;
batch_stride_var = args.batch_stride_var;
batch_stride_mean = args.batch_stride_mean;
batch_stride_gamma = args.batch_stride_gamma;
batch_stride_beta = args.batch_stride_beta;
batch_stride_C = args.batch_stride_C;
batch_stride_D = args.batch_stride_D;
output_op = args.epilogue;
semaphore = static_cast<int *>(workspace);
CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");
}
};
/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
@ -466,12 +417,9 @@ public:
public:
//
// Methods
// Host dispatch API
//
CUTLASS_DEVICE
GemmLayernormMainloopFusion() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(
cutlass::gemm::GemmCoord const & problem_size) {
@ -555,12 +503,23 @@ public:
return can_implement(args.problem_size);
}
static size_t get_extra_workspace_size(Arguments const &args,
cutlass::gemm::GemmCoord const &grid_tiled_shape) {
public:
return 0;
//
// Device-only API
//
// Factory invocation
CUTLASS_DEVICE
static void invoke(
Params const &params,
SharedStorage &shared_storage)
{
GemmLayernormMainloopFusion op;
op(params, shared_storage);
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {

View File

@ -41,6 +41,7 @@
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/gemm/kernel/params_universal_base.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -105,16 +106,12 @@ public:
//
/// Argument structure
struct Arguments {
struct Arguments : UniversalArgumentsBase
{
//
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size;
int batch_count;
typename EpilogueOutputOp::Params epilogue;
void const * ptr_A_real;
@ -144,17 +141,13 @@ public:
int64_t batch_stride_B_imag;
int64_t batch_stride_C;
int64_t batch_stride_C_imag;
int64_t batch_stride_D;
int64_t batch_stride_D_imag;
//
// Methods
//
Arguments():
mode(GemmUniversalMode::kGemm),
batch_count(1),
Arguments() :
ptr_A_real(nullptr),
ptr_A_imag(nullptr),
ptr_B_real(nullptr),
@ -163,7 +156,7 @@ public:
ptr_C_imag(nullptr),
ptr_D_real(nullptr),
ptr_D_imag(nullptr)
{ }
{}
/// constructs an arguments structure
Arguments(
@ -194,11 +187,9 @@ public:
int64_t batch_stride_C = 0,
int64_t batch_stride_C_imag = 0,
int64_t batch_stride_D = 0,
int64_t batch_stride_D_imag = 0
):
mode(mode),
problem_size(problem_size),
batch_count(batch_count),
int64_t batch_stride_D_imag = 0)
:
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
epilogue(epilogue),
ptr_A_real(ptr_A_real),
ptr_A_imag(ptr_A_imag),
@ -222,10 +213,8 @@ public:
batch_stride_B_imag(batch_stride_B_imag),
batch_stride_C(batch_stride_C),
batch_stride_C_imag(batch_stride_C_imag),
batch_stride_D(batch_stride_D),
batch_stride_D_imag(batch_stride_D_imag) {
}
batch_stride_D_imag(batch_stride_D_imag)
{}
/// Returns arguments for the transposed problem
Arguments transposed_problem() const {
@ -243,16 +232,30 @@ public:
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params {
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
struct Params : UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC>
{
using ParamsBase = UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC>;
//
// Data members
//
typename Mma::IteratorA::Params params_A_real;
typename Mma::IteratorA::Params params_A_imag;
typename Mma::IteratorB::Params params_B_real;
@ -264,10 +267,6 @@ public:
typename EpilogueOutputOp::Params output_op;
GemmUniversalMode mode;
int batch_count;
int gemm_k_size;
void * ptr_A_real;
void * ptr_A_imag;
void * ptr_B_real;
@ -278,54 +277,28 @@ public:
void * ptr_D_imag;
int64_t batch_stride_A;
int64_t batch_stride_A_imag;
int64_t batch_stride_B;
int64_t batch_stride_B_imag;
int64_t batch_stride_C;
int64_t batch_stride_A_imag;
int64_t batch_stride_B_imag;
int64_t batch_stride_C_imag;
int64_t batch_stride_D;
int64_t batch_stride_D_imag;
int *semaphore;
//
// Methods
// Host dispatch API
//
CUTLASS_HOST_DEVICE
Params():
batch_count(0),
gemm_k_size(0),
swizzle_log_tile(0),
mode(cutlass::gemm::GemmUniversalMode::kGemm),
ptr_A_real(nullptr),
ptr_A_imag(nullptr),
ptr_B_real(nullptr),
ptr_B_imag(nullptr),
ptr_C_real(nullptr),
ptr_C_imag(nullptr),
ptr_D_real(nullptr),
ptr_D_imag(nullptr),
batch_stride_A(0),
batch_stride_A_imag(0),
batch_stride_B(0),
batch_stride_B_imag(0),
batch_stride_C(0),
batch_stride_C_imag(0),
batch_stride_D(0),
batch_stride_D_imag(0),
semaphore(nullptr) { }
/// Default constructor
Params() = default;
CUTLASS_HOST_DEVICE
/// Constructor
Params(
Arguments const &args,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
int gemm_k_size,
void *workspace = nullptr
):
problem_size(args.problem_size),
grid_tiled_shape(grid_tiled_shape),
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
Arguments const &args, /// GEMM application arguments
int device_sms, /// Number of SMs on the device
int sm_occupancy) /// Kernel SM occupancy (in thread blocks)
:
ParamsBase(args, device_sms, sm_occupancy),
params_A_real(args.lda_real),
params_A_imag(args.lda_imag),
params_B_real(args.ldb_real),
@ -335,9 +308,6 @@ public:
params_D_real(args.ldd_real),
params_D_imag(args.ldd_imag),
output_op(args.epilogue),
mode(args.mode),
batch_count(args.batch_count),
gemm_k_size(gemm_k_size),
ptr_A_real(const_cast<void *>(args.ptr_A_real)),
ptr_A_imag(const_cast<void *>(args.ptr_A_imag)),
ptr_B_real(const_cast<void *>(args.ptr_B_real)),
@ -347,21 +317,32 @@ public:
ptr_D_real(args.ptr_D_real),
ptr_D_imag(args.ptr_D_imag),
batch_stride_A(args.batch_stride_A),
batch_stride_A_imag(args.batch_stride_A_imag),
batch_stride_B(args.batch_stride_B),
batch_stride_B_imag(args.batch_stride_B_imag),
batch_stride_C(args.batch_stride_C),
batch_stride_A_imag(args.batch_stride_A_imag),
batch_stride_B_imag(args.batch_stride_B_imag),
batch_stride_C_imag(args.batch_stride_C_imag),
batch_stride_D(args.batch_stride_D),
batch_stride_D_imag(args.batch_stride_D_imag),
semaphore(static_cast<int *>(workspace)) {
batch_stride_D_imag(args.batch_stride_D_imag)
{}
/// Returns the workspace size (in bytes) needed for this problem geometry
size_t get_workspace_size() const
{
size_t workspace_bytes = ParamsBase::get_workspace_size();
if (this->mode == GemmUniversalMode::kGemmSplitKParallel)
{
// Double the size returned by the base class because we need to
// accumulate two ElementC components
workspace_bytes *= 2;
}
return workspace_bytes;
}
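    // Example (hypothetical values): if the base class computes an N-byte
    // split-K workspace for a real-valued GEMM, the planar-complex kernel
    // requests 2 * N bytes, one copy each for the real and imaginary
    // ElementC accumulator planes.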
void update(
Arguments const &args,
void *workspace = nullptr) {
/// Lightweight update given a subset of arguments. Problem geometry is assumed
/// to remain the same.
void update(Arguments const &args)
{
ptr_A_real = const_cast<void *>(args.ptr_A_real);
ptr_A_imag = const_cast<void *>(args.ptr_A_imag);
@ -374,21 +355,11 @@ public:
ptr_D_real = const_cast<void *>(args.ptr_D_real);
ptr_D_imag = const_cast<void *>(args.ptr_D_imag);
batch_stride_A = args.batch_stride_A;
batch_stride_A_imag = args.batch_stride_A_imag;
batch_stride_B = args.batch_stride_B;
batch_stride_B_imag = args.batch_stride_B_imag;
batch_stride_C = args.batch_stride_C;
batch_stride_C_imag = args.batch_stride_C_imag;
batch_stride_D = args.batch_stride_D;
batch_stride_D_imag = args.batch_stride_D_imag;
output_op = args.epilogue;
semaphore = static_cast<int *>(workspace);
}
};
/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
@ -398,15 +369,12 @@ public:
public:
//
// Methods
// Host dispatch API
//
CUTLASS_DEVICE
GemmPlanarComplex() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(Arguments const &args) {
static Status can_implement(Arguments const &args)
{
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
@ -440,12 +408,23 @@ public:
return Status::kSuccess;
}
static size_t get_extra_workspace_size(Arguments const &args,
cutlass::gemm::GemmCoord const &grid_tiled_shape) {
public:
return 0;
//
// Device-only API
//
// Factory invocation
CUTLASS_DEVICE
static void invoke(
Params const &params,
SharedStorage &shared_storage)
{
GemmPlanarComplex op;
op(params, shared_storage);
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {

View File

@ -41,6 +41,7 @@
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/gemm/kernel/params_universal_base.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -105,16 +106,12 @@ public:
//
/// Argument structure
struct Arguments {
struct Arguments : UniversalArgumentsBase
{
//
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size;
int batch_count;
typename EpilogueOutputOp::Params epilogue;
int const *ptr_M;
@ -142,15 +139,11 @@ public:
typename LayoutC::Stride::Index ldd_real;
typename LayoutC::Stride::Index ldd_imag;
int64_t batch_stride_D; // unused
//
// Methods
//
Arguments():
mode(GemmUniversalMode::kArray),
batch_count(1),
ptr_M(nullptr),
ptr_N(nullptr),
ptr_K(nullptr),
@ -161,9 +154,8 @@ public:
ptr_C_real(nullptr),
ptr_C_imag(nullptr),
ptr_D_real(nullptr),
ptr_D_imag(nullptr),
batch_stride_D(0)
{ }
ptr_D_imag(nullptr)
{}
/// constructs an arguments structure
Arguments(
@ -188,11 +180,9 @@ public:
typename LayoutC::Stride::Index ldc_real,
typename LayoutC::Stride::Index ldc_imag,
typename LayoutC::Stride::Index ldd_real,
typename LayoutC::Stride::Index ldd_imag
):
mode(GemmUniversalMode::kArray),
problem_size(problem_size),
batch_count(batch_count),
typename LayoutC::Stride::Index ldd_imag)
:
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
epilogue(epilogue),
ptr_M(ptr_M),
ptr_N(ptr_N),
@ -212,10 +202,8 @@ public:
ldc_real(ldc_real),
ldc_imag(ldc_imag),
ldd_real(ldd_real),
ldd_imag(ldd_imag),
batch_stride_D(0) {
}
ldd_imag(ldd_imag)
{}
/// Returns arguments for the transposed problem
Arguments transposed_problem() const {
@ -232,15 +220,30 @@ public:
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params {
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
struct Params : UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC>
{
using ParamsBase = UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC>;
//
// Data members
//
typename Mma::IteratorA::Params params_A_real;
typename Mma::IteratorA::Params params_A_imag;
typename Mma::IteratorB::Params params_B_real;
@ -249,11 +252,9 @@ public:
typename Epilogue::OutputTileIterator::Params params_C_imag;
typename Epilogue::OutputTileIterator::Params params_D_real;
typename Epilogue::OutputTileIterator::Params params_D_imag;
typename EpilogueOutputOp::Params output_op;
int batch_count;
int const *ptr_M;
int const *ptr_N;
int const *ptr_K;
@ -268,35 +269,19 @@ public:
void * const * ptr_D_imag;
//
// Methods
// Host dispatch API
//
CUTLASS_HOST_DEVICE
Params():
batch_count(0),
swizzle_log_tile(0),
ptr_M(nullptr),
ptr_N(nullptr),
ptr_K(nullptr),
ptr_A_real(nullptr),
ptr_A_imag(nullptr),
ptr_B_real(nullptr),
ptr_B_imag(nullptr),
ptr_C_real(nullptr),
ptr_C_imag(nullptr),
ptr_D_real(nullptr),
ptr_D_imag(nullptr) { }
/// Default constructor
Params() = default;
CUTLASS_HOST_DEVICE
/// Constructor
Params(
Arguments const &args,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
int gemm_k_size = 0, // ignored
void *workspace = nullptr // ignored
):
problem_size(args.problem_size),
grid_tiled_shape(grid_tiled_shape),
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
Arguments const &args, /// GEMM application arguments
int device_sms, /// Number of SMs on the device
int sm_occupancy) /// Kernel SM occupancy (in thread blocks)
:
ParamsBase(args, device_sms, sm_occupancy),
ptr_M(args.ptr_M),
ptr_N(args.ptr_N),
ptr_K(args.ptr_K),
@ -309,7 +294,6 @@ public:
params_D_real(args.ldd_real),
params_D_imag(args.ldd_imag),
output_op(args.epilogue),
batch_count(args.batch_count),
ptr_A_real(args.ptr_A_real),
ptr_A_imag(args.ptr_A_imag),
ptr_B_real(args.ptr_B_real),
@ -317,14 +301,13 @@ public:
ptr_C_real(args.ptr_C_real),
ptr_C_imag(args.ptr_C_imag),
ptr_D_real(args.ptr_D_real),
ptr_D_imag(args.ptr_D_imag) {
}
void update(
Arguments const &args,
void *workspace = nullptr) {
ptr_D_imag(args.ptr_D_imag)
{}
/// Lightweight update given a subset of arguments. Problem geometry is assumed
/// to remain the same.
void update(Arguments const &args)
{
ptr_M = args.ptr_M;
ptr_N = args.ptr_N;
ptr_K = args.ptr_K;
@ -345,6 +328,7 @@ public:
}
};
/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
@ -354,12 +338,9 @@ public:
public:
//
// Methods
// Host dispatch API
//
CUTLASS_DEVICE
GemmPlanarComplexArray() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(Arguments const &args) {
@ -396,12 +377,24 @@ public:
return Status::kSuccess;
}
static size_t get_extra_workspace_size(Arguments const &args,
cutlass::gemm::GemmCoord const &grid_tiled_shape) {
return 0;
public:
//
// Device-only API
//
// Factory invocation
CUTLASS_DEVICE
static void invoke(
Params const &params,
SharedStorage &shared_storage)
{
GemmPlanarComplexArray op;
op(params, shared_storage);
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {

View File

@ -37,12 +37,12 @@
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/params_universal_base.h"
#include "cutlass/trace.h"
@ -55,7 +55,7 @@ namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
@ -101,16 +101,12 @@ public:
//
/// Argument structure
struct Arguments {
struct Arguments : UniversalArgumentsBase
{
//
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size;
int batch_count;
typename EpilogueOutputOp::Params epilogue;
void const * ptr_A;
@ -121,7 +117,6 @@ public:
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_C;
int64_t batch_stride_D;
typename LayoutA::Stride stride_a;
typename LayoutB::Stride stride_b;
@ -140,14 +135,13 @@ public:
//
// Methods
//
Arguments():
mode(GemmUniversalMode::kGemm),
batch_count(1),
Arguments():
ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr),
ptr_gather_A_indices(nullptr),
ptr_gather_B_indices(nullptr),
ptr_scatter_D_indices(nullptr) {}
ptr_scatter_D_indices(nullptr)
{}
/// constructs an arguments structure
Arguments(
@ -169,23 +163,22 @@ public:
typename LayoutC::Stride stride_d,
int const *ptr_gather_A_indices = nullptr,
int const *ptr_gather_B_indices = nullptr,
int const *ptr_scatter_D_indices = nullptr
):
mode(mode),
problem_size(problem_size),
batch_count(batch_count),
int const *ptr_scatter_D_indices = nullptr)
:
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
epilogue(epilogue),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),
stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d),
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
ptr_scatter_D_indices(ptr_scatter_D_indices) {
ptr_scatter_D_indices(ptr_scatter_D_indices)
{
lda = 0;
ldb = 0;
ldc = 0;
ldd = 0;
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
}
}
/// constructs an arguments structure
Arguments(
@ -209,26 +202,26 @@ public:
int const *ptr_gather_B_indices = nullptr,
int const *ptr_scatter_D_indices = nullptr
):
mode(mode),
problem_size(problem_size),
batch_count(batch_count),
epilogue(epilogue),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
epilogue(epilogue),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),
lda(lda), ldb(ldb), ldc(ldc), ldd(ldd),
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
ptr_scatter_D_indices(ptr_scatter_D_indices) {
ptr_scatter_D_indices(ptr_scatter_D_indices)
{
stride_a = make_Coord(lda);
stride_b = make_Coord(ldb);
stride_c = make_Coord(ldc);
stride_d = make_Coord(ldd);
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
}
}
/// Returns arguments for the transposed problem
Arguments transposed_problem() const {
Arguments transposed_problem() const
{
Arguments args(*this);
std::swap(args.problem_size.m(), args.problem_size.n());
std::swap(args.ptr_A, args.ptr_B);
std::swap(args.lda, args.ldb);
@ -240,27 +233,36 @@ public:
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params {
struct Params : UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC>
{
using ParamsBase = UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC>;
//
// Data members
//
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
typename Epilogue::OutputTileIterator::Params params_C;
typename Epilogue::OutputTileIterator::Params params_D;
typename EpilogueOutputOp::Params output_op;
GemmUniversalMode mode;
int batch_count;
int gemm_k_size;
typename EpilogueOutputOp::Params output_op;
void * ptr_A;
void * ptr_B;
@ -270,59 +272,30 @@ public:
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_C;
int64_t batch_stride_D;
int * ptr_gather_A_indices;
int * ptr_gather_B_indices;
int * ptr_scatter_D_indices;
int *semaphore;
//
// Methods
// Host dispatch API
//
CUTLASS_HOST_DEVICE
Params():
swizzle_log_tile(0),
params_A(0),
params_B(0),
params_C(0),
params_D(0),
batch_count(0),
gemm_k_size(0),
mode(cutlass::gemm::GemmUniversalMode::kGemm),
ptr_A(nullptr),
ptr_B(nullptr),
ptr_C(nullptr),
ptr_D(nullptr),
batch_stride_A(0),
batch_stride_B(0),
batch_stride_C(0),
batch_stride_D(0),
ptr_gather_A_indices(nullptr),
ptr_gather_B_indices(nullptr),
ptr_scatter_D_indices(nullptr),
semaphore(nullptr) { }
/// Default constructor
Params() = default;
CUTLASS_HOST_DEVICE
/// Constructor
Params(
Arguments const &args,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
int gemm_k_size,
void *workspace = nullptr
):
problem_size(args.problem_size),
grid_tiled_shape(grid_tiled_shape),
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
Arguments const &args, /// GEMM application arguments
int device_sms, /// Number of SMs on the device
int sm_occupancy) /// Kernel SM occupancy (in thread blocks)
:
ParamsBase(args, device_sms, sm_occupancy),
params_A(args.lda ? make_Coord_with_padding<LayoutA::kStrideRank>(args.lda) : args.stride_a),
params_B(args.ldb ? make_Coord_with_padding<LayoutB::kStrideRank>(args.ldb) : args.stride_b),
params_C(args.ldc ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldc) : args.stride_c),
params_D(args.ldd ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldd) : args.stride_d),
output_op(args.epilogue),
mode(args.mode),
batch_count(args.batch_count),
gemm_k_size(gemm_k_size),
ptr_A(const_cast<void *>(args.ptr_A)),
ptr_B(const_cast<void *>(args.ptr_B)),
ptr_C(const_cast<void *>(args.ptr_C)),
@ -330,19 +303,18 @@ public:
batch_stride_A(args.batch_stride_A),
batch_stride_B(args.batch_stride_B),
batch_stride_C(args.batch_stride_C),
batch_stride_D(args.batch_stride_D),
ptr_gather_A_indices(const_cast<int *>(args.ptr_gather_A_indices)),
ptr_gather_B_indices(const_cast<int *>(args.ptr_gather_B_indices)),
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices)),
semaphore(static_cast<int *>(workspace)) {
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices))
{}
}
CUTLASS_HOST_DEVICE
void update(
Arguments const &args,
void *workspace = nullptr) {
/// Lightweight update given a subset of arguments. Problem geometry is assumed
/// to remain the same.
void update(Arguments const &args)
{
CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");
// Update input/output pointers
ptr_A = const_cast<void *>(args.ptr_A);
ptr_B = const_cast<void *>(args.ptr_B);
ptr_C = const_cast<void *>(args.ptr_C);
@ -352,37 +324,28 @@ public:
ptr_gather_B_indices = const_cast<int *>(args.ptr_gather_B_indices);
ptr_scatter_D_indices = const_cast<int *>(args.ptr_scatter_D_indices);
batch_stride_A = args.batch_stride_A;
batch_stride_B = args.batch_stride_B;
batch_stride_C = args.batch_stride_C;
batch_stride_D = args.batch_stride_D;
output_op = args.epilogue;
semaphore = static_cast<int *>(workspace);
CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");
}
};
/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
public:
//
// Methods
// Host dispatch API
//
CUTLASS_DEVICE
GemmUniversal() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(
cutlass::gemm::GemmCoord const & problem_size) {
cutlass::gemm::GemmCoord const & problem_size)
{
CUTLASS_TRACE_HOST("GemmUniversal::can_implement()");
static int const kAlignmentA = (platform::is_same<LayoutA,
@ -462,15 +425,30 @@ public:
return can_implement(args.problem_size);
}
static size_t get_extra_workspace_size(Arguments const &args,
cutlass::gemm::GemmCoord const &grid_tiled_shape) {
return 0;
public:
//
// Device-only API
//
// Factory invocation
CUTLASS_DEVICE
static void invoke(
Params const &params,
SharedStorage &shared_storage)
{
GemmUniversal op;
op(params, shared_storage);
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
void operator()(
Params const &params,
SharedStorage &shared_storage)
{
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
@ -677,7 +655,7 @@ public:
// Release the semaphore
//
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {

File diff suppressed because it is too large

File diff suppressed because it is too large

Some files were not shown because too many files have changed in this diff