Release 2.11 (#703)
@@ -80,9 +80,9 @@ public:
  typedef value_type *pointer;
  typedef value_type const * const_pointer;

  using ArrayType = Array<T, N>;
  using reference = typename ArrayType::reference;
  using const_reference = typename ArrayType::const_reference;
  using Array = Array<T, N>;
  using reference = typename Array::reference;
  using const_reference = typename Array::const_reference;

public:
@@ -85,6 +85,10 @@ struct Sm86 {
  static int const kMinComputeCapability = 86;
};

struct Sm90 {
  static int const kMinComputeCapability = 90;
};

/// Triggers a breakpoint on the device
CUTLASS_DEVICE
void device_breakpoint() {
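For host-side dispatch, these architecture tags expose kMinComputeCapability so launch code can compare against the device's compute capability. A minimal sketch (illustrative only, not part of this diff; assumes `props` was populated via cudaGetDeviceProperties):

// Gate kernel selection on an architecture tag's minimum compute capability.
template <typename ArchTag>
bool arch_supported(cudaDeviceProp const &props) {
  int device_cc = props.major * 10 + props.minor;       // e.g. 90 on an SM90 device
  return device_cc >= ArchTag::kMinComputeCapability;   // cutlass::arch::Sm90 requires >= 90
}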
@@ -451,7 +451,7 @@ template <>
CUTLASS_DEVICE
void shared_store<16>(uint32_t ptr, void const *src) {
  uint4 const *dst_u128 = reinterpret_cast<uint4 const *>(src);
  asm volatile("ld.shared.v4.u32 [%0], {%1, %2, %3, %4};\n"
  asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n"
    : :
      "r"(ptr),
      "r"(dst_u128->x),
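This hunk fixes shared_store<16>, which previously emitted a shared-memory load (ld.shared) where a store was intended. In plain CUDA C++ the corrected operation is equivalent to a single 128-bit vector store (sketch only; the PTX form above takes a 32-bit shared-state-space address rather than a generic pointer):

__device__ void shared_store_16B_equivalent(uint4 *smem_dst, void const *src) {
  // Copy 16 bytes from `src` into shared memory with one vectorized store
  *smem_dst = *reinterpret_cast<uint4 const *>(src);
}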
@@ -223,4 +223,6 @@ struct SparseMma;
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/mma_sm80.h"
#include "cutlass/arch/mma_sparse_sm80.h"
#include "cutlass/arch/mma_sm90.h"

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1065,7 +1065,7 @@ struct Mma<
  int const *C = reinterpret_cast<int const *>(&c);
  int *D = reinterpret_cast<int *>(&d);

  asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
  asm volatile("_mma.m8n8k32.row.col.u4.s4.sat {%0,%1}, %2, %3, {%4,%5};\n"
      : "=r"(D[0]), "=r"(D[1])
      : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
@@ -1247,7 +1247,8 @@ struct Mma<
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
#if defined(CUTLASS_ARCH_WMMA_ENABLED)

#if (__CUDA_ARCH__ >= 900) || (defined(CUTLASS_ARCH_WMMA_ENABLED))
  using WmmaFragmentA = nvcuda::wmma::fragment<
      nvcuda::wmma::matrix_a,
      Shape::kM,
@@ -1279,6 +1280,7 @@ struct Mma<

  nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR,
                          nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);

#else

  CUTLASS_UNUSED(a);
@@ -1289,14 +1291,7 @@ struct Mma<

#endif // defined(CUTLASS_ARCH_WMMA_ENABLED)

#else
  CUTLASS_UNUSED(a);
  CUTLASS_UNUSED(b);
  CUTLASS_UNUSED(c);
  CUTLASS_UNUSED(d);
  assert(0);
#endif

  }
};

@@ -2156,6 +2156,7 @@ struct Mma<

  int const *C = reinterpret_cast<int const *>(&c);
  int *D = reinterpret_cast<int *>(&d);

  asm volatile(
      "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc {%0,%1,%2,%3}, "
      "{%4,%5,%6,%7}, "
include/cutlass/arch/mma_sm90.h (new file, 131 lines)
@@ -0,0 +1,131 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Matrix multiply
*/

#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif

#include "mma.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"

////////////////////////////////////////////////////////////////////////////////

#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 8))
#define CUTLASS_ARCH_MMA_SM90_SUPPORTED 1
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
#define CUTLASS_ARCH_MMA_SM90_ENABLED
#endif
#endif

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace arch {

////////////////////////////////////////////////////////////////////////////////
/// Matrix Multiply-Add 16x8x4 fp64
////////////////////////////////////////////////////////////////////////////////

/// Matrix multiply-add operation: F64 = F64 * F64 + F64
template <>
struct Mma<
  gemm::GemmShape<16,8,4>,
  32,
  double,
  layout::RowMajor,
  double,
  layout::ColumnMajor,
  double,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<16,8,4>;

  using ElementA = double;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<double, 2>;

  using ElementB = double;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<double, 1>;

  using ElementC = double;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<double, 4>;

  using Operator = OpMultiplyAdd;

  using ArchTag = arch::Sm90;

  CUTLASS_HOST_DEVICE
  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
                  FragmentC const &c) const {

#if defined(CUTLASS_ARCH_MMA_SM90_ENABLED)

    double const *A = reinterpret_cast<double const *>(&a);
    double const *B = reinterpret_cast<double const *>(&b);

    double const *C = reinterpret_cast<double const *>(&c);
    double *D = reinterpret_cast<double *>(&d);

    asm volatile("mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
        : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3])
        : "d"(A[0]), "d"(A[1]),
          "d"(B[0]),
          "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3]));

#else

    CUTLASS_UNUSED(d);
    CUTLASS_UNUSED(a);
    CUTLASS_UNUSED(b);
    CUTLASS_UNUSED(c);
    CUTLASS_NOT_IMPLEMENTED();

#endif
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace arch
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
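A warp invokes this FP64 tensor-core atom through the usual Mma interface. A hedged usage sketch (assumed, not from this commit; distributing the 16x8x4 tile across the warp's fragments is omitted):

using MmaF64 = cutlass::arch::Mma<
    cutlass::gemm::GemmShape<16, 8, 4>, 32,
    double, cutlass::layout::RowMajor,
    double, cutlass::layout::ColumnMajor,
    double, cutlass::layout::RowMajor,
    cutlass::arch::OpMultiplyAdd>;

__device__ void warp_dmma(MmaF64::FragmentC &d,
                          MmaF64::FragmentA const &a,
                          MmaF64::FragmentB const &b,
                          MmaF64::FragmentC const &c) {
  MmaF64 mma;
  mma(d, a, b, c);  // issues one mma.sync.aligned.m16n8k4 ... f64 per warp
}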
File diff suppressed because it is too large
include/cutlass/barrier.h (new file, 201 lines)
@@ -0,0 +1,201 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Implementation of a CTA-wide barrier for inter-CTA synchronization.
*/

#pragma once

#include "cutlass/cutlass.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// CTA-wide semaphore for inter-CTA synchronization.
struct Barrier
{

public:

  /// Flag type
  using T = int;

  /// Initial flag value
  static const T INIT = 0;


protected:

  /// Load flag, as a strong operation (int specialization)
  CUTLASS_DEVICE
  static int ld_strong(int *ptr)
  {
    int state = 0;

#if (__CUDA_ARCH__ >= 700)
    /// SM70 and newer use memory consistency qualifiers
    asm volatile ("ld.global.relaxed.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));
#else
    asm volatile ("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));
#endif // (__CUDA_ARCH__ >= 700)

    return state;
  }

  /// Store flag, as a strong operation (int specialization)
  CUTLASS_DEVICE
  static void st_strong(int *ptr, int val)
  {
#if (__CUDA_ARCH__ >= 700)
    /// SM70 and newer use memory consistency qualifiers
    asm volatile ("st.global.relaxed.gpu.b32 [%0], %1;\n" : : "l"(ptr), "r"(val));
#else
    asm volatile ("st.cg.global.b32 [%0], %1;\n" : : "l"(ptr), "r"(val));
#endif // (__CUDA_ARCH__ >= 700)
  }


  /// Reduce into flag, with release pattern (int specialization)
  CUTLASS_DEVICE
  static void red_release(int *ptr, int val)
  {
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
#if (__CUDA_ARCH__ >= 700)
    /// SM70 and newer use memory consistency qualifiers
    asm volatile ("red.release.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val));
#else
    __threadfence();
    atomicAdd(ptr, val);
#endif // (__CUDA_ARCH__ >= 700)
#endif
  }


public:

  /// Uses thread[0] to wait for at least the specified count of signals on the given flag counter
  CUTLASS_DEVICE
  static void wait_lt(void *lock_ptr, int thread_idx, int flag_idx, int count)
  {
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
    T *flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;

    if (thread_idx == 0)
    {
      // Spin-loop
      #pragma unroll 1
      while(ld_strong(flag_ptr) < count) {}
    }

    __syncthreads();
#endif
  }

  /// Uses thread[0] to wait until the given flag counter equals the specified value
  CUTLASS_DEVICE
  static void wait_eq(void *lock_ptr, int thread_idx, int flag_idx, T val = 1)
  {
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
    T *flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;

    if (thread_idx == 0)
    {
      // Spin-loop
      #pragma unroll 1
      while(ld_strong(flag_ptr) != val) {}
    }

    __syncthreads();
#endif
  }

  /// Uses thread[0] to wait until the given flag counter equals the specified value, then resets it to zero
  CUTLASS_DEVICE
  static void wait_eq_reset(void *lock_ptr, int thread_idx, int flag_idx, T val = 1) {
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
    T *flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;

    if (thread_idx == 0)
    {
      // Spin-loop
      #pragma unroll 1
      while(atomicCAS(flag_ptr, val, 0) != val) {}
    }

    __syncthreads();
#endif
  }

  /// Increment the arrival count for a flag
  CUTLASS_DEVICE
  static void arrive_inc(void *lock_ptr, int thread_idx, int flag_idx)
  {
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
    T* flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;

    __syncthreads();

    if (thread_idx == 0) {
      red_release(flag_ptr, 1);
    }
#endif
  }


  /// Increment the arrival counts for a range of flags
  CUTLASS_DEVICE
  static void arrive_range_inc(void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1)
  {
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
    int flag_idx = first_flag_idx + thread_idx;
    T* flag_ptr = reinterpret_cast<T*>(lock_ptr) + flag_idx;

    // Barrier to make sure all other threads in block have written their data
    __syncthreads();

    // Selected threads increment their flags
    if (thread_idx < count) {
      red_release(flag_ptr, 1);
    }
#endif
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
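A typical producer/consumer pattern with this barrier (illustrative sketch; assumes `flags` is global memory zero-initialized to Barrier::INIT and both CTAs agree on flag_idx):

__device__ void producer_signal(int *flags, int flag_idx) {
  // __syncthreads() inside, then thread 0 performs a release-reduction (+1)
  cutlass::Barrier::arrive_inc(flags, threadIdx.x, flag_idx);
}

__device__ void consumer_wait(int *flags, int flag_idx, int expected_arrivals) {
  // Thread 0 spins until the flag reaches the expected count; the block then proceeds
  cutlass::Barrier::wait_eq(flags, threadIdx.x, flag_idx, expected_arrivals);
}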
@@ -35,7 +35,9 @@
 */
#pragma once

#if !defined(__CUDACC_RTC__)
#if defined(__CUDACC_RTC__)
#include "cutlass/floating_point_nvrtc.h"
#else
#include <cmath>
#include <limits>
#include <cstdint>
@@ -71,8 +73,7 @@ struct alignas(2) bfloat16_t {
  }

  /// Default constructor
  CUTLASS_HOST_DEVICE
  bfloat16_t() : storage(0) { }
  bfloat16_t() = default;

  /// Floating-point conversion - round toward nearest
  CUTLASS_HOST_DEVICE
@@ -40,6 +40,7 @@
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/complex.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_types.h"

include/cutlass/block_striped.h (new file, 259 lines)
@@ -0,0 +1,259 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Utilities for performing block-striped access (load, store, reduce) of trivially-copyable,
    statically-sized array types to global memory.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/wmma_array.h"
#include "cutlass/functional.h"
#include "cutlass/complex.h"

namespace cutlass {

/////////////////////////////////////////////////////////////////////////////////////////////////
// AccessWidth
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Computes the maximal power-of-two that evenly divides the size of T, capped at Limit
template <
  typename T,
  int Limit>
struct AccessWidth
{
  // Inductive case
  template <
    int ObjectBytes,      /// Size of T in bytes
    int AlignBytes,       /// Template induction variable
    bool IsAligned =      /// Whether ObjectBytes is an even multiple of AlignBytes
      ((AlignBytes <= Limit) && (ObjectBytes % AlignBytes == 0))>
  struct Detail
  {
    static const int value = Detail<ObjectBytes, AlignBytes * 2>::value;
  };

  // Base case (ObjectBytes is not an even multiple of AlignBytes)
  template <
    int ObjectBytes,      /// Size of T in bytes
    int AlignBytes>       /// Template induction variable
  struct Detail<ObjectBytes, AlignBytes, false>
  {
    static const int value = AlignBytes / 2;
  };

  /// The maximal power-of-two that evenly divides the size of T
  static const int value = Detail<
    (int) sizeof(T),
    1>::value;
};

/////////////////////////////////////////////////////////////////////////////////////////////////
// StripedAccessType
/////////////////////////////////////////////////////////////////////////////////////////////////

/// ReinterpretCast type for striping a trivially-copyable type in global memory
/// (Default specialization. Striping granularity is type T.)
template <
  typename T,           /// Data type
  int TransferBytes =   /// Data access width (16 byte max for global memory access on current architectures)
    AccessWidth<T, 16>::value>
struct alignas(TransferBytes) StripedAccessType : public T
{};

/// ReinterpretCast type for striping a trivially-copyable type in global memory
/// (Specialization for cutlass::Array<T>. Striping granularity is a multiple of T.)
template <
  typename T,           /// Array element type
  int N,                /// Number of elements in array
  bool RegisterSized,   /// T is register-sized
  int TransferBytes>    /// Data access width
struct StripedAccessType<
  Array<T, N, RegisterSized>,
  TransferBytes>
: public AlignedArray<
    T,                                                // Element type of StripedAccessType
    __NV_STD_MAX(1, TransferBytes / (int) sizeof(T)), // Number of elements T in StripedAccessType
    TransferBytes>                                    // Alignment of StripedAccessType
{};

#if defined(CUTLASS_ARCH_WMMA_ENABLED)

/// ReinterpretCast type for striping a trivially-copyable type in global memory
/// (Specialization for cutlass::WmmaFragmentArray<T>. Striping granularity is a multiple of T.)
template<
  typename Use,
  int m,
  int n,
  int k,
  typename ElementT,
  typename Layout,
  int kFragments,
  int TransferBytes>
struct StripedAccessType<
  WmmaFragmentArray<nvcuda::wmma::fragment<Use, m, n, k, ElementT, Layout>, kFragments>,
  TransferBytes>
: public AlignedArray<
    ElementT,
    __NV_STD_MAX(1, TransferBytes / (int) sizeof(ElementT)),
    TransferBytes>
{};

#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED)

/////////////////////////////////////////////////////////////////////////////////////////////////
// BlockStriped
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Utility for performing block-striped access (load, store) of trivially-copyable,
/// statically-sized array types to global memory
template <
  int BlockThreads,
  typename ArrayT,
  typename T,
  typename AccessT = StripedAccessType<ArrayT> >
struct BlockStriped
{
  /// Number of striped accesses
  static const int kStripes = int(sizeof(ArrayT) / sizeof(AccessT));
  static_assert(kStripes > 0, "AccessT type must be smaller than or equal to ArrayT type");

  /// Load
  CUTLASS_DEVICE
  static void load(ArrayT &data, T *ptr, int thread_idx)
  {
    AccessT *access_input = reinterpret_cast<AccessT*>(ptr);
    AccessT *access_data = reinterpret_cast<AccessT*>(&data);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kStripes; ++i) {
      access_data[i] = access_input[(BlockThreads * i) + thread_idx];
    }
  }

  /// Load & Add
  CUTLASS_DEVICE
  static void load_add(ArrayT &data, T *ptr, int thread_idx)
  {
    AccessT *access_input = reinterpret_cast<AccessT*>(ptr);
    AccessT *access_data = reinterpret_cast<AccessT*>(&data);

    plus<AccessT> add;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kStripes; ++i)
    {
      access_data[i] = add(access_data[i], access_input[(BlockThreads * i) + thread_idx]);
    }
  }

  /// Store
  CUTLASS_DEVICE
  static void store(T *ptr, const ArrayT &data, int thread_idx)
  {
    AccessT *access_output = reinterpret_cast<AccessT*>(ptr);
    const AccessT *access_data = reinterpret_cast<const AccessT*>(&data);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kStripes; ++i) {
      access_output[(BlockThreads * i) + thread_idx] = access_data[i];
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////
// BlockStripedReduce
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Utility for performing block-striped access (load, store, reduce) of trivially-copyable,
/// statically-sized array types to global memory.
/// (Default specialization)
template <
  int BlockThreads,
  typename ArrayT,
  typename T>
struct BlockStripedReduce : BlockStriped<BlockThreads, ArrayT, T, T>
{
  /// Reduce
  CUTLASS_DEVICE
  static void reduce(T *ptr, const ArrayT &data, int thread_idx)
  {
    cutlass::red<T> reduce;
    const T *access_data = reinterpret_cast<const T*>(&data);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < BlockStripedReduce::kStripes; ++i) {
      reduce(ptr + (BlockThreads * i) + thread_idx, access_data[i]);
    }
  }
};

/// Utility for performing block-striped access (load, store, reduce) of trivially-copyable,
/// statically-sized array types to global memory.
/// (Specialization for half_t. Uses half2 vectorized-reduction.)
template <
  int BlockThreads,
  typename ArrayT>
struct BlockStripedReduce<BlockThreads, ArrayT, half_t> : BlockStriped<BlockThreads, ArrayT, half_t, half2>
{
  static_assert(BlockStripedReduce::kStripes % 2 == 0, "Array of half must be even number in length");

  /// Reduce
  CUTLASS_DEVICE
  static void reduce(half_t *ptr, const ArrayT &data, int thread_idx)
  {
    cutlass::red<half2> reduce;
    half2 *access_output = reinterpret_cast<half2*>(ptr);
    const half2 *access_data = reinterpret_cast<const half2*>(&data);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < BlockStripedReduce::kStripes; ++i)
    {
      reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]);
    }
  }
};

} // namespace cutlass
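Putting BlockStriped to work, an illustrative sketch (assumed usage): each thread of a 128-thread block owns an Array<float, 8>; the block stages the fragments to a global workspace in coalesced, striped order, and a later pass folds staged values back in:

constexpr int kBlockThreads = 128;
using Fragment = cutlass::Array<float, 8>;
using Striped  = cutlass::BlockStriped<kBlockThreads, Fragment, float>;

// `workspace` is assumed to hold kBlockThreads * 8 floats
__device__ void stage(float *workspace, Fragment const &frag) {
  Striped::store(workspace, frag, threadIdx.x);
}

__device__ void accumulate(float *workspace, Fragment &frag) {
  Striped::load_add(frag, workspace, threadIdx.x);  // frag[i] += staged value
}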
@@ -32,6 +32,8 @@

#include <cuComplex.h>

#include <cuda_fp16.h>

#if defined(__CUDACC_RTC__)
#include <cuda/std/cstdint>
#else
@@ -39,6 +41,7 @@
#endif

#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/half.h"
#include "cutlass/real.h"

@@ -53,8 +56,10 @@

namespace cutlass {

//////////////////////////////////////////////////////////////////////////////////////////////////

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Enumerated type describing a transformation on a complex value.
enum class ComplexTransform {
  kNone,
@@ -147,15 +152,18 @@ class complex
  // Methods
  //

  /// Constructor
  CUTLASS_HOST_DEVICE
  complex(T r = T(0)) : _real(r), _imag(T(0)) {}
  /// Default constructor
  complex() = default;

  /// Constructor
  CUTLASS_HOST_DEVICE
  complex(T r) : _real(r), _imag(T(0)) {}

  /// Constructor
  CUTLASS_HOST_DEVICE
  complex(T r, T i) : _real(r), _imag(i) {}

  /// Constructor
  template<typename A>
  CUTLASS_HOST_DEVICE
  complex(complex<A> const &z) : _real(static_cast<T>(z.real())), _imag(static_cast<T>(z.imag())) {}
@@ -197,6 +205,24 @@ class complex
    return complex<T>(this->real() + rhs.real(), this->imag() + rhs.imag());
  }

  /// Reduction into memory address. Components may update out of order.
  template <typename OtherT>
  CUTLASS_DEVICE void red(complex<OtherT> *ptr) const {
    static_assert(platform::is_same<T, OtherT>::value, "Component type must match");
    cutlass::red<T> reduce;
    reduce(&ptr->_real, _real);
    reduce(&ptr->_imag, _imag);
  }

  /// Reduction into memory address. Components may update out of order. (Half specialization)
  CUTLASS_DEVICE void red(complex<half_t> *ptr) const {
    static_assert(platform::is_same<T, half_t>::value, "Component type must match");
    half2 *h2_ptr = reinterpret_cast<half2*>(ptr);
    half2 h2_data = reinterpret_cast<half2&>(*this);
    cutlass::red<half2> reduce;
    reduce(h2_ptr, h2_data);
  }

  /// Subtraction
  template <typename A>
  CUTLASS_HOST_DEVICE complex<T> operator-(complex<A> const &rhs) const {
@@ -506,13 +532,14 @@ CUTLASS_HOST_DEVICE bool operator<(complex<T> const &lhs, complex<T> const &rhs)

/// Partial specialization for complex-valued type.
template <typename T>
struct RealType< complex<T> > {
struct RealType< complex<T> >
{
  using Type = T;

  /// Number of elements
  static int const kExtent = 2;

  CUTLASS_HOST_DEVICE
  static complex<T> from_real(double x) {
    return complex<T>(static_cast<T>(x));
  }
@@ -550,6 +577,127 @@ struct is_complex<complex<T>> {
  static bool const value = true;
};

/////////////////////////////////////////////////////////////////////////////////////////////////
// functional.h numeric specializations
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Squares with optional conversion
template <typename T, typename Output>
struct magnitude_squared<complex<T>, Output> {
  CUTLASS_HOST_DEVICE
  Output operator()(complex<T> lhs) const {
    multiplies<Output> mul_op;

    Output y_r = Output(lhs.real());
    Output y_i = Output(lhs.imag());

    return mul_op(y_r, y_r) + mul_op(y_i, y_i);
  }
};

/// Fused multiply-add
template <typename T>
struct multiply_add<complex<T>, complex<T>, complex<T>> {
  CUTLASS_HOST_DEVICE
  complex<T> operator()(
    complex<T> const &a,
    complex<T> const &b,
    complex<T> const &c) const {

    T real = c.real();
    T imag = c.imag();

    real += a.real() * b.real();
    real += -a.imag() * b.imag();
    imag += a.real() * b.imag();
    imag += a.imag() * b.real();

    return complex<T>{
      real,
      imag
    };
  }
};

/// Fused multiply-add
template <typename T>
struct multiply_add<complex<T>, T, complex<T>> {
  CUTLASS_HOST_DEVICE
  complex<T> operator()(
    complex<T> const &a,
    T const &b,
    complex<T> const &c) const {

    T real = c.real();
    T imag = c.imag();

    real += a.real() * b;
    imag += a.imag() * b;

    return complex<T>{
      real,
      imag
    };
  }
};

/// Fused multiply-add
template <typename T>
struct multiply_add<T, complex<T>, complex<T>> {
  CUTLASS_HOST_DEVICE
  complex<T> operator()(
    T const &a,
    complex<T> const &b,
    complex<T> const &c) const {

    T real = c.real();
    T imag = c.imag();

    real += a * b.real();
    imag += a * b.imag();

    return complex<T>{
      real,
      imag
    };
  }
};

/// Conjugate
template <typename T>
struct conjugate<complex<T>> {
  CUTLASS_HOST_DEVICE
  complex<T> operator()(complex<T> const &a) const {
    return conj(a);
  }
};

/// Computes the square of a difference with optional conversion
template <typename T, typename Output>
struct magnitude_squared_difference<complex<T>, Output> {
  CUTLASS_HOST_DEVICE
  Output operator()(complex<T> lhs, complex<T> rhs) const {
    multiplies<Output> mul_op;

    Output y_r = Output(lhs.real()) - Output(rhs.real());
    Output y_i = Output(lhs.imag()) - Output(rhs.imag());

    return mul_op(y_r, y_r) + mul_op(y_i, y_i);
  }
};

/// Reduces value into the data pointed to by ptr (complex<T> specialization)
template <typename T>
struct red<complex<T>> {
  CUTLASS_DEVICE
  void operator()(complex<T> *ptr, const complex<T> &data)
  {
    data.red(ptr);
  }
};

//////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass

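The complex multiply_add specialization above computes d = a*b + c componentwise: d.real = c.real + a.real*b.real - a.imag*b.imag and d.imag = c.imag + a.real*b.imag + a.imag*b.real. A small worked usage (illustrative only):

cutlass::complex<float> a(1.0f, 2.0f), b(3.0f, -1.0f), c(0.5f, 0.5f);
cutlass::multiply_add<
    cutlass::complex<float>,
    cutlass::complex<float>,
    cutlass::complex<float>> fma_op;
// real: 0.5 + (1)(3) - (2)(-1) = 5.5;  imag: 0.5 + (1)(-1) + (2)(3) = 5.5
cutlass::complex<float> d = fma_op(a, b, c);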
@@ -247,7 +247,7 @@ public:
  CUTLASS_HOST_DEVICE
  cutlass::Tensor4DCoord filter_extent() const {

    return cutlass::Tensor4DCoord ({K, R, S, C});
    return cutlass::Tensor4DCoord ({K, R, S, C / groups});
  }

  /// Returns output extent as Tensor4DCoord
@@ -336,7 +336,7 @@ cutlass::gemm::GemmCoord implicit_gemm_problem_size(
    return gemm::GemmCoord(
      problem_size.N * problem_size.P * problem_size.Q,
      problem_size.K,
      problem_size.R * problem_size.S * problem_size.C
      problem_size.R * problem_size.S * problem_size.C / problem_size.groups
    );
  case Operator::kDgrad:
    return gemm::GemmCoord(
@@ -451,6 +451,18 @@ int implicit_gemm_k_iterations(
      default:
        break;
    }
  } else if (algorithm == IteratorAlgorithm::kOptimized) {
    // The current optimized iterator only supports GroupMode::kSingleGroup
    if (group_mode == GroupMode::kSingleGroup) {
      switch (conv_operator) {
      case Operator::kFprop:
        iterations = problem_size.R * problem_size.S * ((channels_per_group + threadblock_K - 1) / threadblock_K);
        break;

      default:
        break;
      }
    }
  }

}
@@ -459,6 +471,25 @@ int implicit_gemm_k_iterations(
}

template <int N = 1, int Output_P = 1, int Output_Q = 1>
CUTLASS_HOST_DEVICE
int depthwise_gemm_k_iterations(
  Operator conv_operator,
  int threadblock_K,
  Conv2dProblemSize const &problem_size,
  IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic,
  GroupMode group_mode = GroupMode::kNone,
  int threadblock_N = 0) {

  int n = problem_size.N;
  int p = (problem_size.P + Output_P - 1) / Output_P;
  int q = (problem_size.Q + Output_Q - 1) / Output_Q;

  int iterations = (n * p * q + problem_size.split_k_slices - 1) / problem_size.split_k_slices;
  return iterations;
}

CUTLASS_HOST_DEVICE
int implicit_gemm_k_iterations_per_channel(
  Operator conv_operator,
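To make the grouped k-iteration arithmetic concrete, a worked example under assumed sizes (illustrative only):

// Fprop, R = S = 3, C = 128, groups = 4, threadblock_K = 32:
//   channels_per_group = C / groups = 32
//   iterations = R * S * ceil(channels_per_group / threadblock_K)
//              = 3 * 3 * 1 = 9 k-iterations per threadblock,
// versus 3 * 3 * ceil(128 / 32) = 36 if the K extent were not divided by groups.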
@@ -100,14 +100,16 @@ enum class IteratorAlgorithm {
  kAnalytic,            ///< functionally correct in all cases but lower performance
  kOptimized,           ///< optimized for R <= 32, S <= 32 and unity-stride dgrad
  kFixedChannels,       ///< Analytic algorithm optimized for fixed channel count (C == AccessSize)
  kFewChannels          ///< Analytic algorithm optimized for few channels (C divisible by AccessSize)
  kFewChannels,         ///< Analytic algorithm optimized for few channels (C divisible by AccessSize)
  kFixedStrideDilation  ///< Optimized for fixed stride and dilation
};

/// Distinguishes among partial specializations that accelerate certain problems where convolution
/// stride is unit.
enum class StrideSupport {
  kStrided,   ///< arbitrary convolution stride
  kUnity      ///< unit convolution stride
  kUnity,     ///< unit convolution stride
  kFixed      ///< fixed convolution stride
};

/// Identifies split-K mode
@@ -125,6 +127,38 @@ enum class GroupMode {
  kDepthwise  ///< One CTA calculates cta_n groups (problem_size.C == problem_size.K == problem_size.groups)
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Shape of a tensor
template <
  int N = 1,
  int H = 1,
  int W = 1,
  int C = 1
>
struct TensorNHWCShape {
  static int const kN = N;
  static int const kH = H;
  static int const kW = W;
  static int const kC = C;

  static int const kHW = H * W;
  static int const kNHW = N * kHW;
  static int const kNHWC = N * H * W * C;

  static int const kCount = kNHWC;

  //
  // Static member functions
  //

  /// Returns a Coord object
  CUTLASS_HOST_DEVICE
  static Coord<4> toCoord() {
    return make_Coord(kN, kH, kW, kC);
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace conv
include/cutlass/conv/device/direct_convolution.h (new file, 269 lines)
@@ -0,0 +1,269 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for device-level Depthwise Convolution
*/

#pragma once

#include <limits>

#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#include "cutlass/conv/convolution.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////

template<typename DirectConvolutionKernel_>
class DirectConvolution {
public:

  using UnderlyingKernel = DirectConvolutionKernel_;

  using ElementA = typename UnderlyingKernel::ElementA;
  using LayoutA = typename UnderlyingKernel::LayoutA;
  using ElementB = typename UnderlyingKernel::ElementB;
  using LayoutB = typename UnderlyingKernel::LayoutB;
  using ElementC = typename UnderlyingKernel::ElementC;
  using LayoutC = typename UnderlyingKernel::LayoutC;
  using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator;
  using ElementCompute = typename UnderlyingKernel::ElementCompute;
  using OperatorClass = typename UnderlyingKernel::OperatorClass;
  using ArchTag = typename UnderlyingKernel::ArchTag;
  using ThreadblockShape = typename UnderlyingKernel::ThreadblockShape;
  using WarpShape = typename UnderlyingKernel::WarpShape;
  using InstructionShape = typename UnderlyingKernel::InstructionShape;
  using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle;
  using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp;
  static int const kStages = UnderlyingKernel::kStages;
  static int const kConvDim = UnderlyingKernel::kConvDim;
  using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator;
  using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator;
  using MathOperator = typename UnderlyingKernel::MathOperator;

  static cutlass::conv::Operator const kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator;
  static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm;
  static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport;
  static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode;

  static int const kWarpCount =
    (ThreadblockShape::kM / WarpShape::kM) *
    (ThreadblockShape::kN / WarpShape::kN) *
    (ThreadblockShape::kK / WarpShape::kK);

  /// Argument structure
  using Arguments = typename UnderlyingKernel::Arguments;

  using ReorderKernel = typename UnderlyingKernel::ReorderKernel;

private:

  /// Kernel parameters object
  typename UnderlyingKernel::Params params_;

public:

  /// Constructs Implicit GEMM
  DirectConvolution() { }

  /// Determines whether the Implicit GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {

    // dispatch to iterators
    Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size);
    if (Status::kSuccess != status) {
      return status;
    }

    status = UnderlyingKernel::Mma::IteratorB::can_implement(args.problem_size);
    if (Status::kSuccess != status) {
      return status;
    }

    if (kGroupMode != conv::GroupMode::kDepthwise) {
      return Status::kErrorInvalidProblem;
    }

    // C and K should be multiple of groups
    if (args.problem_size.K != args.problem_size.groups &&
        args.problem_size.C != args.problem_size.groups) {
      return Status::kErrorInvalidProblem;
    }


    static int const kAlignmentC = UnderlyingKernel::Epilogue::OutputTileIterator::kElementsPerAccess;
    if (kConvolutionalOperator == conv::Operator::kFprop) {
      if (args.problem_size.K % kAlignmentC)
        return Status::kErrorMisalignedOperand;
    } else if (kConvolutionalOperator == conv::Operator::kDgrad) {
      if (args.problem_size.C % kAlignmentC)
        return Status::kErrorMisalignedOperand;
    } else if (kConvolutionalOperator == conv::Operator::kWgrad) {
      if (args.problem_size.C % kAlignmentC)
        return Status::kErrorMisalignedOperand;
    }

    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    dim3 grid = threadblock_swizzle.get_grid_shape(
      threadblock_swizzle.get_tiled_shape(
        kConvolutionalOperator,
        args.problem_size,
        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
        args.problem_size.split_k_slices));

    if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
          grid.z <= std::numeric_limits<uint16_t>::max())) {

      return Status::kErrorInvalidProblem;
    }

    return Status::kSuccess;
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {
    return 0;
  }

  /// Initializes GEMM state from arguments.
  Status initialize(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    // initialize the params structure from the arguments
    params_ = typename UnderlyingKernel::Params(
      args,
      static_cast<int *>(workspace)
    );

    int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage));

    if (smem_size >= (48 << 10)) {
      cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel<UnderlyingKernel>,
                                                cudaFuncAttributeMaxDynamicSharedMemorySize,
                                                smem_size);

      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }
    }

    return Status::kSuccess;
  }

  /// Initializes GEMM state from arguments.
  Status update(Arguments const &args, void *workspace = nullptr) {

    // update the params structure from the arguments
    params_.ptr_A = args.ref_A.data();
    params_.ptr_B = args.ref_B.data();
    params_.ptr_C = args.ref_C.data();
    params_.ptr_D = args.ref_D.data();
    params_.output_op = args.output_op;
    params_.ptr_reordered_B = args.ref_reordered_B.data();
    params_.semaphore = static_cast<int *>(workspace);

    return Status::kSuccess;
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {

    // Launch reorder kernel
    if (params_.ptr_reordered_B != nullptr) {
      dim3 grid = ReorderKernel::get_grid_shape(params_);
      dim3 block = ReorderKernel::get_block_shape();

      cutlass::Kernel<ReorderKernel><<<grid, block, 0, stream>>>(params_);
    }

    // Launch main kernel
    ThreadblockSwizzle threadblock_swizzle;

    dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
    dim3 block(32 * kWarpCount, 1, 1);

    // Dynamic SMEM size based on input params.
    int smem_size = int(params_.get_smem_size());

    // Make sure we can use that much shared memory.
    cudaError_t status =
      cudaFuncSetAttribute(cutlass::Kernel<UnderlyingKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    if (status != cudaSuccess)
      return Status::kErrorInternal;


    cutlass::Kernel<UnderlyingKernel><<<grid, block, smem_size, stream>>>(params_);

    cudaError_t result = cudaGetLastError();

    return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace, stream);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }

  int get_smem_size() { return int(params_.get_smem_size()); }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}
}
}

/////////////////////////////////////////////////////////////////////////////////////////////////
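Host-side use mirrors ImplicitGemmConvolution. A hedged sketch (`DepthwiseKernel` and `make_args()` are hypothetical stand-ins for a concrete kernel type and argument construction):

using Conv = cutlass::conv::device::DirectConvolution<DepthwiseKernel>;

Conv conv_op;
typename Conv::Arguments args = make_args();  // problem size + tensor refs
if (Conv::can_implement(args) == cutlass::Status::kSuccess) {
  conv_op.initialize(args);  // builds Params, configures dynamic shared memory
  conv_op();                 // launches the reorder kernel (if any), then the main kernel
}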
@@ -52,33 +52,33 @@ template<typename ImplicitGemmKernel_>
class ImplicitGemmConvolution {
public:

  using ImplicitGemmKernel = ImplicitGemmKernel_;
  using UnderlyingKernel = ImplicitGemmKernel_;

  using ElementA = typename ImplicitGemmKernel::ElementA;
  using LayoutA = typename ImplicitGemmKernel::LayoutA;
  using ElementB = typename ImplicitGemmKernel::ElementB;
  using LayoutB = typename ImplicitGemmKernel::LayoutB;
  using ElementC = typename ImplicitGemmKernel::ElementC;
  using LayoutC = typename ImplicitGemmKernel::LayoutC;
  using ElementAccumulator = typename ImplicitGemmKernel::ElementAccumulator;
  using ElementCompute = typename ImplicitGemmKernel::ElementCompute;
  using OperatorClass = typename ImplicitGemmKernel::OperatorClass;
  using ArchTag = typename ImplicitGemmKernel::ArchTag;
  using ThreadblockShape = typename ImplicitGemmKernel::ThreadblockShape;
  using WarpShape = typename ImplicitGemmKernel::WarpShape;
  using InstructionShape = typename ImplicitGemmKernel::InstructionShape;
  using ThreadblockSwizzle = typename ImplicitGemmKernel::ThreadblockSwizzle;
  using EpilogueOutputOp = typename ImplicitGemmKernel::EpilogueOutputOp;
  static int const kStages = ImplicitGemmKernel::kStages;
  static int const kConvDim = ImplicitGemmKernel::kConvDim;
  using WarpMmaOperator = typename ImplicitGemmKernel::WarpMmaOperator;
  using ArchMmaOperator = typename ImplicitGemmKernel::ArchMmaOperator;
  using MathOperator = typename ImplicitGemmKernel::MathOperator;
  using ElementA = typename UnderlyingKernel::ElementA;
  using LayoutA = typename UnderlyingKernel::LayoutA;
  using ElementB = typename UnderlyingKernel::ElementB;
  using LayoutB = typename UnderlyingKernel::LayoutB;
  using ElementC = typename UnderlyingKernel::ElementC;
  using LayoutC = typename UnderlyingKernel::LayoutC;
  using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator;
  using ElementCompute = typename UnderlyingKernel::ElementCompute;
  using OperatorClass = typename UnderlyingKernel::OperatorClass;
  using ArchTag = typename UnderlyingKernel::ArchTag;
  using ThreadblockShape = typename UnderlyingKernel::ThreadblockShape;
  using WarpShape = typename UnderlyingKernel::WarpShape;
  using InstructionShape = typename UnderlyingKernel::InstructionShape;
  using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle;
  using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp;
  static int const kStages = UnderlyingKernel::kStages;
  static int const kConvDim = UnderlyingKernel::kConvDim;
  using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator;
  using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator;
  using MathOperator = typename UnderlyingKernel::MathOperator;

  static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemmKernel::kConvolutionalOperator;
  static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = ImplicitGemmKernel::kIteratorAlgorithm;
  static cutlass::conv::StrideSupport const kStrideSupport = ImplicitGemmKernel::kStrideSupport;
  static cutlass::conv::GroupMode const kGroupMode = ImplicitGemmKernel::kGroupMode;
  static cutlass::conv::Operator const kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator;
  static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm;
  static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport;
  static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode;

  static int const kWarpCount =
    (ThreadblockShape::kM / WarpShape::kM) *
@@ -86,12 +86,12 @@ public:
    (ThreadblockShape::kK / WarpShape::kK);

  /// Argument structure
  using Arguments = typename ImplicitGemmKernel::Arguments;
  using Arguments = typename UnderlyingKernel::Arguments;

private:

  /// Kernel parameters object
  typename ImplicitGemmKernel::Params params_;
  typename UnderlyingKernel::Params params_;

public:

@@ -102,12 +102,12 @@ public:
  static Status can_implement(Arguments const &args) {

    // dispatch to iterators
    Status status = ImplicitGemmKernel::Mma::IteratorA::can_implement(args.problem_size);
    Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size);
    if (Status::kSuccess != status) {
      return status;
    }

    status = ImplicitGemmKernel::Mma::IteratorB::can_implement(args.problem_size);
    status = UnderlyingKernel::Mma::IteratorB::can_implement(args.problem_size);
    if (Status::kSuccess != status) {
      return status;
    }
@@ -138,9 +138,15 @@ public:
    if (kGroupMode == conv::GroupMode::kMultipleGroup && ThreadblockShape::kN % k_per_group) {
      return Status::kErrorInvalidProblem;
    }

    // current optimized iterator algo only supports SingleGroup mode
    if (kIteratorAlgorithm == IteratorAlgorithm::kOptimized &&
        kGroupMode != conv::GroupMode::kSingleGroup) {
      return Status::kErrorInvalidProblem;
    }
  }

  static int const kAlignmentC = ImplicitGemmKernel::Epilogue::OutputTileIterator::kElementsPerAccess;
  static int const kAlignmentC = UnderlyingKernel::Epilogue::OutputTileIterator::kElementsPerAccess;
  if (kConvolutionalOperator == conv::Operator::kFprop) {
    if (args.problem_size.K % kAlignmentC)
      return Status::kErrorMisalignedOperand;
@@ -249,15 +255,15 @@ public:
    }

    // initialize the params structure from the arguments
    params_ = typename ImplicitGemmKernel::Params(
    params_ = typename UnderlyingKernel::Params(
      args,
      static_cast<int *>(workspace)
    );

    int smem_size = int(sizeof(typename ImplicitGemmKernel::SharedStorage));
    int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage));

    if (smem_size >= (48 << 10)) {
      cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel<ImplicitGemmKernel>,
      cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel<UnderlyingKernel>,
                                                cudaFuncAttributeMaxDynamicSharedMemorySize,
                                                smem_size);

@@ -292,9 +298,9 @@ public:
    dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
    dim3 block(32 * kWarpCount, 1, 1);

    int smem_size = int(sizeof(typename ImplicitGemmKernel::SharedStorage));
    int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage));

    cutlass::Kernel<ImplicitGemmKernel><<<grid, block, smem_size, stream>>>(params_);
    cutlass::Kernel<UnderlyingKernel><<<grid, block, smem_size, stream>>>(params_);

    cudaError_t result = cudaGetLastError();

@@ -89,7 +89,7 @@ template <
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines a kernel for Conv2dGroupFprop specialization for Analytic IteratorAlgorithm and multistage
/// pipeline.
/// pipeline that supports all GroupMode.
template <
  typename ElementA,
  typename LayoutA,
@@ -135,6 +135,13 @@ struct DefaultConv2dGroupFprop <
  AlignmentB
> {

  static_assert(std::is_same<LayoutA, cutlass::layout::TensorNHWC>::value,
    "Current group conv only support NHWC layout");
  static_assert(std::is_same<LayoutB, cutlass::layout::TensorNHWC>::value,
    "Current group conv only support NHWC layout");
  static_assert(std::is_same<LayoutC, cutlass::layout::TensorNHWC>::value,
    "Current group conv only support NHWC layout");

  // Define the core components from GEMM
  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
    ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
@ -215,6 +222,267 @@ struct DefaultConv2dGroupFprop <
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dGroupFprop specialization for Optimized IteratorAlgorithm and multistage
|
||||
/// pipeline that supports GroupMode::kSingleGroup.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dGroupFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
GroupMode::kSingleGroup,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
static_assert(std::is_same<LayoutA, cutlass::layout::TensorNHWC>::value,
|
||||
"Current group conv only support NHWC layout");
|
||||
static_assert(std::is_same<LayoutB, cutlass::layout::TensorNHWC>::value,
|
||||
"Current group conv only support NHWC layout");
|
||||
static_assert(std::is_same<LayoutC, cutlass::layout::TensorNHWC>::value,
|
||||
"Current group conv only support NHWC layout");
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
Stages, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA,
|
||||
AccessTypeA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB,
|
||||
AccessTypeB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmMultistage<
|
||||
ThreadblockShape,
|
||||
IteratorA,
|
||||
SmemIteratorA,
|
||||
arch::CacheOperation::Always,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
CacheOpB,
|
||||
MmaPolicy,
|
||||
Stages
|
||||
>;
|
||||
|
||||
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape,
|
||||
WarpMmaTensorOp,
|
||||
kPartitionsK,
|
||||
EpilogueOutputOp,
|
||||
EpilogueOutputOp::kCount
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
||||
Mma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop,
|
||||
Conv2dProblemSize,
|
||||
GroupMode::kSingleGroup
|
||||
>;
|
||||
};
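// A hedged instantiation sketch for the kSingleGroup / kOptimized
// specialization above; the element types, tile shapes, epilogue, and swizzle
// below are illustrative choices, not a validated configuration:
//
//   using GroupFpropKernel = cutlass::conv::kernel::DefaultConv2dGroupFprop<
//       cutlass::half_t, cutlass::layout::TensorNHWC,   // A (activations)
//       cutlass::half_t, cutlass::layout::TensorNHWC,   // B (filter)
//       float,           cutlass::layout::TensorNHWC,   // C (output)
//       float,                                          // accumulator
//       cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
//       cutlass::gemm::GemmShape<128, 128, 32>,         // threadblock tile
//       cutlass::gemm::GemmShape<64, 64, 32>,           // warp tile
//       cutlass::gemm::GemmShape<16, 8, 16>,            // instruction shape
//       cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
//       cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
//       3,                                              // stages
//       cutlass::arch::OpMultiplyAdd,
//       cutlass::conv::GroupMode::kSingleGroup,
//       cutlass::conv::IteratorAlgorithm::kOptimized,
//       cutlass::conv::StrideSupport::kStrided,
//       8, 8>::Kernel;                                  // alignments A / B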
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dGroupFprop specialization for Optimized IteratorAlgorithm and
|
||||
/// 2 stage pipeline that supports GroupMode::kSingleGroup.
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultConv2dGroupFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
MathOperatorTag,
|
||||
GroupMode::kSingleGroup,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
static_assert(std::is_same<LayoutA, cutlass::layout::TensorNHWC>::value,
|
||||
"Current group conv only support NHWC layout");
|
||||
static_assert(std::is_same<LayoutB, cutlass::layout::TensorNHWC>::value,
|
||||
"Current group conv only support NHWC layout");
|
||||
static_assert(std::is_same<LayoutC, cutlass::layout::TensorNHWC>::value,
|
||||
"Current group conv only support NHWC layout");
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
|
||||
ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
|
||||
2, MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ThreadMapA,
|
||||
AccessTypeA
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::TileIterator<
|
||||
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ThreadMapB,
|
||||
AccessTypeB
|
||||
>
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::ImplicitGemmPipelined<
|
||||
ThreadblockShape,
|
||||
IteratorA,
|
||||
SmemIteratorA,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
MmaPolicy
|
||||
>;
|
||||
|
||||
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename detail::DefaultConvEpilogue<
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpMmaTensorOp,
|
||||
kPartitionsK,
|
||||
EpilogueOutputOp
|
||||
>::Epilogue;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
|
||||
Mma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop,
|
||||
Conv2dProblemSize,
|
||||
GroupMode::kSingleGroup
|
||||
>;
|
||||
};
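// Note: this Stages == 2 partial specialization routes to the
// ImplicitGemmPipelined mainloop (and wraps each access iterator in a
// TileIterator), whereas the multistage specialization above relies on the
// asynchronous-copy mainloop; the two-stage path is what pre-SM80 targets
// would use.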
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
@ -39,14 +39,21 @@
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/conv/kernel/default_conv2d.h"
|
||||
#include "cutlass/conv/kernel/direct_convolution.h"
|
||||
|
||||
#include "cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h"
|
||||
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
|
||||
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
|
||||
#include "cutlass/conv/threadblock/depthwise_fprop_pipelined.h"
|
||||
|
||||
// Direct Conv Related Header files
|
||||
#include "cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h"
|
||||
#include "cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h"
|
||||
|
||||
#include "cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h"
|
||||
#include "cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
@ -54,7 +61,7 @@ namespace conv {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Defines a kernel for Conv2dFprop
|
||||
/// Defines a kernel for DepthwiseFprop
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
@ -80,12 +87,43 @@ template <
|
||||
int AlignmentB = cutlass::sizeof_bits<ElementB>::value / cutlass::sizeof_bits<ElementB>::value
|
||||
> struct DefaultDepthwiseFprop;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Defines a kernel for DepthwiseFprop with direct convolution algorithm
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename OperatorClass,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename ThreadBlockOutputShape,
|
||||
typename FilterShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
|
||||
conv::StrideSupport StrideSupport = StrideSupport::kStrided,
|
||||
// MatrixShape<Height, Width>
|
||||
typename StrideShape = cutlass::MatrixShape<-1, -1>,
|
||||
// MatrixShape<Height, Width>
|
||||
typename DilationShape = cutlass::MatrixShape<-1, -1>,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value
|
||||
> struct DefaultDepthwiseDirect2dConvFprop;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// OpClassSimt convolutions
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Depthwise specialization for Analytic IteratorAlgorithm,
|
||||
/// 2 stage pipeline, and FFMA-based mainloop for SM50
|
||||
/// Defines a kernel for Depthwise specialization for Analytic IteratorAlgorithm
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
@ -210,6 +248,338 @@ struct DefaultDepthwiseFprop <
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Defines a kernel for Depthwise specialization for direct 2d conv implementation,
|
||||
/// multiple stage pipeline, and SIMT-based mainloop
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename ThreadBlockOutputShape,
|
||||
typename FilterShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
typename StrideShape,
|
||||
typename DilationShape,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultDepthwiseDirect2dConvFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassSimt,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
ThreadBlockOutputShape,
|
||||
FilterShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kOptimized,
|
||||
StrideSupport,
|
||||
StrideShape,
|
||||
DilationShape,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
// One warp handles all of the groups assigned to the CTA.
|
||||
static_assert(ThreadblockShape::kN == WarpShape::kN,
|
||||
"ThreadblockShape::kN should be the same as WarpShape::kN");
|
||||
static_assert(ThreadblockShape::kK == FilterShape::kCount && WarpShape::kK == FilterShape::kCount,
|
||||
"ThreadblockShape::kK and WarpShape::kK should be the same as the filter size");
|
||||
static_assert(ThreadblockShape::kM % WarpShape::kM == 0,
|
||||
"ThreadblockShape::kM must be divisible by WarpShape::kM");
|
||||
static_assert(ThreadBlockOutputShape::kN == 1, "ThreadBlockOutputShape::kN should be 1");
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore = typename cutlass::conv::threadblock::DepthwiseDirectConvMmaCoreWithLaneAccessSize<
|
||||
ThreadblockShape,
|
||||
ThreadBlockOutputShape,
|
||||
FilterShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
ElementA,
|
||||
layout::RowMajor,
|
||||
ElementB,
|
||||
layout::ColumnMajor,
|
||||
ElementAccumulator,
|
||||
layout::RowMajor,
|
||||
arch::OpClassSimt,
|
||||
128,
|
||||
128,
|
||||
Stages,
|
||||
MathOperatorTag>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM,ThreadblockShape::kN>, // < outputShape:KMNK, groups per cta>
|
||||
ThreadBlockOutputShape,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kN, FilterShape::kCount>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
using ThreadOutputShape = typename MmaCore::ThreadOutputShape;
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA =
|
||||
((sizeof_bits<ElementA>::value * AlignmentA) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultDirectConvEpilogueSimt<
|
||||
ThreadblockShape, // < outputShape:KMNK, groups per cta>
|
||||
WarpMmaSimtOp,
|
||||
EpilogueOutputOp,
|
||||
EpilogueOutputOp::kCount,
|
||||
ThreadOutputShape,
|
||||
ThreadBlockOutputShape
|
||||
>::Epilogue;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::DepthwiseFpropDirectConvMultipleStage<
|
||||
ThreadblockShape,
|
||||
IteratorA,
|
||||
SmemIteratorA,
|
||||
CacheOpA,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
CacheOpB,
|
||||
MmaPolicy,
|
||||
Stages,
|
||||
Epilogue
|
||||
>;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::DirectConvolution<
|
||||
Mma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop,
|
||||
Conv2dProblemSize,
|
||||
cutlass::conv::GroupMode::kDepthwise,
|
||||
ThreadBlockOutputShape
|
||||
>;
|
||||
};
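// Note: when stride and dilation are known at compile time, the
// kFixedStrideDilation specialization below can additionally derive a static
// ActivationShape, trading the runtime-parameterized iterator used here for
// fully unrolled address computation.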
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Defines a kernel for Depthwise specialization for direct 2d conv implementation,
|
||||
/// multiple stage pipeline, and SIMT-based mainloop
|
||||
template <
|
||||
typename ElementA,
|
||||
typename LayoutA,
|
||||
typename ElementB,
|
||||
typename LayoutB,
|
||||
typename ElementC,
|
||||
typename LayoutC,
|
||||
typename ElementAccumulator,
|
||||
typename ArchTag,
|
||||
typename ThreadblockShape,
|
||||
typename ThreadBlockOutputShape,
|
||||
typename FilterShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename EpilogueOutputOp,
|
||||
typename ThreadblockSwizzle,
|
||||
int Stages,
|
||||
typename MathOperatorTag,
|
||||
conv::StrideSupport StrideSupport,
|
||||
typename StrideShape,
|
||||
typename DilationShape,
|
||||
int AlignmentA,
|
||||
int AlignmentB
|
||||
>
|
||||
struct DefaultDepthwiseDirect2dConvFprop <
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
arch::OpClassSimt,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
ThreadBlockOutputShape,
|
||||
FilterShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kFixedStrideDilation,
|
||||
StrideSupport,
|
||||
StrideShape,
|
||||
DilationShape,
|
||||
AlignmentA,
|
||||
AlignmentB
|
||||
> {
|
||||
|
||||
|
||||
|
||||
// One warp handles all of the groups assigned to the CTA.
|
||||
static_assert(ThreadblockShape::kN == WarpShape::kN,
|
||||
"ThreadblockShape::kN should be the same as WarpShape::kN");
|
||||
static_assert(ThreadblockShape::kK == FilterShape::kCount && WarpShape::kK == FilterShape::kCount,
|
||||
"ThreadblockShape::kK and WarpShape::kK should be the same as the filter size");
|
||||
static_assert(ThreadblockShape::kM % WarpShape::kM == 0,
|
||||
"ThreadblockShape::kM must be divisible by WarpShape::kM");
|
||||
static_assert(ThreadBlockOutputShape::kN == 1, "ThreadBlockOutputShape::kN should be 1");
|
||||
|
||||
static_assert(StrideShape::kRow >= 0 && StrideShape::kColumn >= 0, "Stride should be fixed");
|
||||
static_assert(DilationShape::kRow >= 0 && DilationShape::kColumn >= 0, "Dilation should be fixed");
|
||||
|
||||
// Activations loaded by threadblock
|
||||
static int const ActivationShapeH = (ThreadBlockOutputShape::kH - 1) * StrideShape::kRow +
|
||||
(FilterShape::kRow - 1) * DilationShape::kRow + 1;
|
||||
|
||||
static int const ActivationShapeW = (ThreadBlockOutputShape::kW - 1) * StrideShape::kColumn +
|
||||
(FilterShape::kColumn - 1) * DilationShape::kColumn + 1;
|
||||
|
||||
using ActivationShape =
|
||||
cutlass::conv::TensorNHWCShape<1, ActivationShapeH, ActivationShapeW, ThreadblockShape::kN >;
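// Hedged numeric check of the two formulas above, assuming an illustrative
// 4x4 output tile, 3x3 filter, stride 2, and dilation 1:
//   H = (4 - 1) * 2 + (3 - 1) * 1 + 1 = 9
// so each threadblock stages a 9x9 activation patch for its output tile.
static_assert((4 - 1) * 2 + (3 - 1) * 1 + 1 == 9, "example activation extent");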
|
||||
|
||||
// Define the core components from GEMM
|
||||
using MmaCore = typename cutlass::conv::threadblock::DepthwiseDirectConvMmaCoreWithLaneAccessSize<
|
||||
ThreadblockShape,
|
||||
ThreadBlockOutputShape,
|
||||
FilterShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
ElementA,
|
||||
layout::RowMajor,
|
||||
ElementB,
|
||||
layout::ColumnMajor,
|
||||
ElementAccumulator,
|
||||
layout::RowMajor,
|
||||
arch::OpClassSimt,
|
||||
128,
|
||||
128,
|
||||
Stages,
|
||||
MathOperatorTag,
|
||||
IteratorAlgorithm::kFixedStrideDilation,
|
||||
StrideShape,
|
||||
DilationShape,
|
||||
ActivationShape>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using IteratorA =
|
||||
cutlass::conv::threadblock::DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM,ThreadblockShape::kN>, // < outputShape:KMNK, groups per cta>
|
||||
ThreadBlockOutputShape,
|
||||
StrideShape,
|
||||
DilationShape,
|
||||
ActivationShape,
|
||||
ElementA, LayoutA,
|
||||
ThreadMapA
|
||||
>;
|
||||
|
||||
using SmemIteratorA = typename MmaCore::SmemIteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
|
||||
using IteratorB =
|
||||
cutlass::conv::threadblock::DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized<
|
||||
cutlass::MatrixShape<ThreadblockShape::kN, FilterShape::kCount>,
|
||||
ElementB, LayoutB,
|
||||
ThreadMapB
|
||||
>;
|
||||
|
||||
using SmemIteratorB = typename MmaCore::SmemIteratorB;
|
||||
|
||||
// Warp-level GEMM components
|
||||
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
|
||||
using MmaPolicy = typename MmaCore::MmaPolicy;
|
||||
using ThreadOutputShape = typename MmaCore::ThreadOutputShape;
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA =
|
||||
((sizeof_bits<ElementA>::value * AlignmentA) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<ElementB>::value * AlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the epilogue
|
||||
using Epilogue = typename epilogue::threadblock::DefaultDirectConvEpilogueSimt<
|
||||
ThreadblockShape, // < outputShape:KMNK, groups per cta>
|
||||
WarpMmaSimtOp,
|
||||
EpilogueOutputOp,
|
||||
EpilogueOutputOp::kCount,
|
||||
ThreadOutputShape,
|
||||
ThreadBlockOutputShape
|
||||
>::Epilogue;
|
||||
|
||||
// Define the Mma
|
||||
using Mma = threadblock::DepthwiseFpropDirectConvMultipleStage<
|
||||
ThreadblockShape,
|
||||
IteratorA,
|
||||
SmemIteratorA,
|
||||
CacheOpA,
|
||||
IteratorB,
|
||||
SmemIteratorB,
|
||||
CacheOpB,
|
||||
MmaPolicy,
|
||||
Stages,
|
||||
Epilogue,
|
||||
IteratorAlgorithm::kFixedStrideDilation
|
||||
>;
|
||||
|
||||
// Define the kernel
|
||||
using Kernel = cutlass::conv::kernel::DirectConvolution<
|
||||
Mma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
conv::Operator::kFprop,
|
||||
Conv2dProblemSize,
|
||||
cutlass::conv::GroupMode::kDepthwise,
|
||||
ThreadBlockOutputShape
|
||||
>;
|
||||
};
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace conv
|
||||
|
||||
505
include/cutlass/conv/kernel/direct_convolution.h
Normal file
@ -0,0 +1,505 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a multi-staged Depthwise Convolution kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/conv2d_problem_size.h"
|
||||
#include "cutlass/conv/conv3d_problem_size.h"
|
||||
#include "cutlass/epilogue/threadblock/output_iterator_parameter.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Parameters structure
|
||||
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad)
|
||||
typename Arguments_, ///! Kernel Arguments
|
||||
typename ConvOutputIteratorParameter_, ///! Output Iterator Params
|
||||
typename ConvProblemSize_ = Conv2dProblemSize, ///! Convolutional operator on 2D or 3D problem
|
||||
conv::GroupMode GroupMode_ = conv::GroupMode::kNone, ///! Group mode
|
||||
typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1> > ///! OutputShape per ThreadBlock
|
||||
struct DirectConvolutionParams {
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueOutputOp = typename Epilogue::OutputOp;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
using ThreadBlockOutputShape = ThreadBlockOutputShape_;
|
||||
static Operator const kConvolutionalOperator = ConvOperator;
|
||||
using ConvProblemSize = ConvProblemSize_;
|
||||
using Arguments = Arguments_;
|
||||
using ConvOutputIteratorParameter = ConvOutputIteratorParameter_;
|
||||
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm;
|
||||
static conv::GroupMode const kGroupMode = GroupMode_;
|
||||
static int const kStages = Mma::kStages;
|
||||
|
||||
ConvProblemSize problem_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
gemm::GemmCoord implicit_gemm_problem_size;
|
||||
int swizzle_log_tile;
|
||||
int smem_size_;
|
||||
|
||||
int gemm_k_iterations;
|
||||
int gemm_k_iterations_per_channel;
|
||||
typename Mma::IteratorA::Params iterator_A;
|
||||
typename Mma::IteratorA::Element const *ptr_A;
|
||||
typename Mma::IteratorB::Params iterator_B;
|
||||
typename Mma::IteratorB::Element const *ptr_B;
|
||||
typename Mma::IteratorB::Element *ptr_reordered_B;
|
||||
typename Epilogue::OutputTileIterator::Params iterator_C;
|
||||
typename Epilogue::OutputTileIterator::Element *ptr_C;
|
||||
typename Epilogue::OutputTileIterator::Params iterator_D;
|
||||
typename Epilogue::OutputTileIterator::Element *ptr_D;
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
int *semaphore;
|
||||
SplitKMode split_k_mode;
|
||||
int split_k_slices;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
DirectConvolutionParams() : swizzle_log_tile(0), gemm_k_iterations(0) {}
|
||||
|
||||
///
|
||||
CUTLASS_HOST_DEVICE
|
||||
DirectConvolutionParams(Arguments const &args, int *semaphore = nullptr)
|
||||
: problem_size(args.problem_size),
|
||||
implicit_gemm_problem_size(
|
||||
cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)),
|
||||
iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())),
|
||||
ptr_A(args.ref_A.data()),
|
||||
iterator_B(Mma::IteratorB::getParams(args.problem_size, args.ref_B.layout())),
|
||||
ptr_B(args.ref_B.data()),
|
||||
ptr_reordered_B(args.ref_reordered_B.data()),
|
||||
iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), args.problem_size),
|
||||
ptr_C(args.ref_C.data()),
|
||||
iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), args.problem_size),
|
||||
ptr_D(args.ref_D.data()),
|
||||
output_op(args.output_op),
|
||||
semaphore(semaphore),
|
||||
split_k_mode(args.split_k_mode),
|
||||
split_k_slices(args.problem_size.split_k_slices) {
|
||||
gemm_k_iterations =
|
||||
depthwise_gemm_k_iterations<ThreadBlockOutputShape::kN,
|
||||
ThreadBlockOutputShape::kH,
|
||||
ThreadBlockOutputShape::kW>(kConvolutionalOperator,
|
||||
ThreadblockShape::kK,
|
||||
args.problem_size,
|
||||
kIteratorAlgorithm,
|
||||
kGroupMode,
|
||||
ThreadblockShape::kN);
|
||||
|
||||
gemm_k_iterations_per_channel = implicit_gemm_k_iterations_per_channel(
|
||||
kConvolutionalOperator, ThreadblockShape::kK, args.problem_size, kIteratorAlgorithm);
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||||
kConvolutionalOperator,
|
||||
problem_size,
|
||||
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
||||
args.problem_size.split_k_slices);
|
||||
|
||||
swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
|
||||
|
||||
// Dynamic SMEM usage because stride and dilation are runtime params.
|
||||
smem_size_ = (iterator_A.activation_size * kStages + iterator_B.filter_size);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
int get_smem_size() {
|
||||
// Dynamic Smem Size
|
||||
return smem_size_;
|
||||
}
|
||||
};
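// Hedged host-side sketch (names such as DirectConvKernel, grid, block, and
// stream are illustrative): since stride and dilation are runtime parameters,
// the launcher must query get_smem_size() and raise the dynamic shared-memory
// limit before launching, mirroring the cudaFuncSetAttribute pattern used for
// the implicit-GEMM kernels earlier in this commit.
//
//   int smem_size = params.get_smem_size();
//   if (smem_size >= (48 << 10)) {
//     cudaFuncSetAttribute(cutlass::Kernel<DirectConvKernel>,
//                          cudaFuncAttributeMaxDynamicSharedMemorySize,
//                          smem_size);
//   }
//   cutlass::Kernel<DirectConvKernel><<<grid, block, smem_size, stream>>>(params);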
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <typename Params_, typename ElementB_>
|
||||
struct ReorderKernel {
|
||||
using Params = Params_;
|
||||
using ElementB = ElementB_;
|
||||
|
||||
union SharedStorage {};
|
||||
|
||||
static unsigned int const kReorderKernelThreadPerCTA = 128;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
ReorderKernel() {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static dim3 get_grid_shape(Params const ¶ms) {
|
||||
return dim3{static_cast<unsigned int>(
|
||||
(params.problem_size.filter_size() + kReorderKernelThreadPerCTA - 1) /
|
||||
kReorderKernelThreadPerCTA),
|
||||
1,
|
||||
1};
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static dim3 get_block_shape() { return dim3{kReorderKernelThreadPerCTA, 1, 1}; }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
||||
int64_t m = static_cast<int64_t>(params.problem_size.groups);
|
||||
int64_t n = static_cast<int64_t>(params.problem_size.filter_size() / params.problem_size.K);
|
||||
const ElementB *src_with_type = static_cast<const ElementB *>(params.ptr_B);
|
||||
ElementB *dst_with_type = static_cast<ElementB *>(params.ptr_reordered_B);
|
||||
|
||||
int64_t linear_index = blockIdx.x * kReorderKernelThreadPerCTA + threadIdx.x;
|
||||
int64_t index_m = linear_index / n;
|
||||
int64_t index_n = linear_index % n;
|
||||
int64_t new_linear_index = index_m + index_n * m;
|
||||
|
||||
if (linear_index < m * n) {
|
||||
dst_with_type[new_linear_index] = src_with_type[linear_index];
|
||||
}
|
||||
return;
|
||||
}
|
||||
};
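/// Hedged host-side sketch of the permutation ReorderKernel performs: the
/// filter is viewed as an m x n matrix (m = groups, n = filter elements per
/// group) and transposed from row-major to column-major order so the
/// mainloop can read it contiguously per group. The function name is
/// illustrative and not part of this commit's API.
template <typename ElementB>
void reorder_filter_host(ElementB *dst, ElementB const *src, int64_t m, int64_t n) {
  for (int64_t linear = 0; linear < m * n; ++linear) {
    int64_t i = linear / n;        // group index
    int64_t j = linear % n;        // filter element within the group
    dst[i + j * m] = src[linear];  // same mapping as ReorderKernel::operator()
  }
}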
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad)
|
||||
typename ConvProblemSize_ = Conv2dProblemSize, ///! Convolutional operator on 2D or 3D problem
|
||||
conv::GroupMode GroupMode_ = conv::GroupMode::kNone, ///! Group mode
|
||||
typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1>
|
||||
>
|
||||
struct DirectConvolution {
|
||||
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueOutputOp = typename Epilogue::OutputOp;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
using ThreadBlockOutputShape = ThreadBlockOutputShape_;
|
||||
static Operator const kConvolutionalOperator = ConvOperator;
|
||||
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
using ElementC = typename EpilogueOutputOp::ElementOutput;
|
||||
|
||||
/// Set output tensor C layout
|
||||
using LayoutC = LayoutA;
|
||||
|
||||
using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator;
|
||||
using ElementCompute = typename EpilogueOutputOp::ElementCompute;
|
||||
|
||||
using WarpMmaOperator = typename Mma::Policy::Operator;
|
||||
|
||||
using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
|
||||
using MathOperator = typename ArchMmaOperator::Operator;
|
||||
|
||||
using OperatorClass = typename WarpMmaOperator::OperatorClass;
|
||||
using ArchTag = typename WarpMmaOperator::ArchTag;
|
||||
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
using WarpShape = typename WarpMmaOperator::Shape;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<1, 1, 1>;
|
||||
|
||||
static int const kStages = Mma::kStages;
|
||||
static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm;
|
||||
static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
using TensorRefA = typename Mma::IteratorA::TensorRef;
|
||||
using TensorRefB = typename Mma::IteratorB::TensorRef;
|
||||
using TensorRefC = cutlass::TensorRef<ElementC, LayoutC>;
|
||||
|
||||
/// Check that iterator A and B convolution dimensions are the same and
|
||||
// set device::ImplicitGemmConvolution::kConvDim
|
||||
static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim,
|
||||
"Convolution on different different dimensions is not supported");
|
||||
static int const kConvDim = Mma::IteratorA::kConvDim;
|
||||
|
||||
/// Conv dimension and problem size structure (Conv2d or Conv3d)
|
||||
using ConvProblemSize = ConvProblemSize_;
|
||||
|
||||
static conv::GroupMode const kGroupMode = GroupMode_;
|
||||
|
||||
|
||||
//
|
||||
//
|
||||
//
|
||||
using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter<
|
||||
LayoutC,
|
||||
typename Epilogue::OutputTileIterator::Layout,
|
||||
TensorRefC,
|
||||
ConvOperator,
|
||||
ConvProblemSize
|
||||
>;
|
||||
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
ConvProblemSize problem_size;
|
||||
TensorRefA ref_A;
|
||||
TensorRefB ref_B;
|
||||
TensorRefB ref_reordered_B;
|
||||
TensorRefC ref_C;
|
||||
TensorRefC ref_D;
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
SplitKMode split_k_mode;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
ConvProblemSize const & problem_size
|
||||
):
|
||||
problem_size(problem_size) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
ConvProblemSize const & problem_size,
|
||||
TensorRefA const & ref_A,
|
||||
TensorRefB const & ref_B,
|
||||
TensorRefC const & ref_C,
|
||||
TensorRefC const & ref_D,
|
||||
typename EpilogueOutputOp::Params const & output_op,
|
||||
TensorRefB const & ref_reordered_B = nullptr,
|
||||
SplitKMode const & split_k_mode = SplitKMode::kSerial
|
||||
):
|
||||
problem_size(problem_size),
|
||||
ref_A(ref_A),
|
||||
ref_B(ref_B),
|
||||
ref_C(ref_C),
|
||||
ref_D(ref_D),
|
||||
output_op(output_op),
|
||||
ref_reordered_B(ref_reordered_B),
|
||||
split_k_mode(split_k_mode)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
using Params =
|
||||
typename cutlass::conv::kernel::DirectConvolutionParams<Mma,
|
||||
Epilogue,
|
||||
ThreadblockSwizzle,
|
||||
kConvolutionalOperator,
|
||||
Arguments,
|
||||
ConvOutputIteratorParameter,
|
||||
ConvProblemSize,
|
||||
kGroupMode,
|
||||
ThreadBlockOutputShape>;
|
||||
|
||||
using ReorderKernel = typename cutlass::conv::kernel::ReorderKernel<Params, ElementB>;
|
||||
|
||||
/// Shared memory storage structure
|
||||
union SharedStorage {
|
||||
typename Mma::SharedStorage main_loop;
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
DirectConvolution() { }
|
||||
|
||||
/// Executes one ImplicitGEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
||||
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_idx =
|
||||
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// Early exit if threadblock is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() ||
|
||||
params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) {
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
int iterator_column_offset = 0;
|
||||
int filter_row_offset = 0;
|
||||
if (kGroupMode != GroupMode::kNone) {
|
||||
if (kGroupMode == GroupMode::kDepthwise) {
|
||||
iterator_column_offset += threadblock_tile_idx.n() * Mma::Shape::kN;
|
||||
}
|
||||
}
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename Mma::IteratorA iterator_A(
|
||||
params.iterator_A,
|
||||
params.problem_size,
|
||||
params.ptr_A,
|
||||
thread_idx,
|
||||
MatrixCoord(
|
||||
threadblock_tile_idx.m() + threadblock_tile_idx.k(),
|
||||
iterator_column_offset
|
||||
)
|
||||
);
|
||||
|
||||
typename Mma::IteratorB iterator_B(
|
||||
params.iterator_B,
|
||||
params.problem_size,
|
||||
params.ptr_reordered_B,
|
||||
thread_idx,
|
||||
MatrixCoord(
|
||||
filter_row_offset,
|
||||
iterator_column_offset
|
||||
)
|
||||
);
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
EpilogueOutputOp output_op(params.output_op);
|
||||
|
||||
// Compute logical position within grid
|
||||
threadblock_tile_idx =
|
||||
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
|
||||
MatrixCoord threadblock_offset(
|
||||
threadblock_tile_idx.m() + threadblock_tile_idx.k(),
|
||||
threadblock_tile_idx.n() * Mma::Shape::kN
|
||||
);
|
||||
|
||||
// Tile iterator writing to destination tensor
|
||||
typename Epilogue::OutputTileIterator iterator_D(
|
||||
params.iterator_D,
|
||||
params.ptr_D,
|
||||
ConvOutputIteratorParameter::extent(params.problem_size),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
|
||||
// Tile iterator reading from source accumulator tensor
|
||||
typename Epilogue::OutputTileIterator iterator_C(
|
||||
params.iterator_C,
|
||||
params.ptr_C,
|
||||
ConvOutputIteratorParameter::extent(params.problem_size),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
|
||||
|
||||
// Construct the epilogue
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue,
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
lane_idx);
|
||||
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
// Epilogue is fused in the mainloop
|
||||
mma(params.gemm_k_iterations,
|
||||
accumulators,
|
||||
iterator_A,
|
||||
params.iterator_A,
|
||||
iterator_B,
|
||||
params.iterator_B,
|
||||
accumulators,
|
||||
epilogue,
|
||||
output_op,
|
||||
iterator_D,
|
||||
iterator_C,
|
||||
params.split_k_slices);
|
||||
}
|
||||
};
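// Hedged end-to-end sketch of driving this kernel from the host (all names
// below -- DirectConv, grid, block, stream -- are illustrative, and the
// reordered-filter buffer must be allocated by the caller):
//
//   typename DirectConv::Arguments args(
//       problem_size, ref_A, ref_B, ref_C, ref_D, {alpha, beta}, ref_reordered_B);
//   typename DirectConv::Params params(args);
//
//   // One-time filter permutation (see ReorderKernel above).
//   using Reorder = typename DirectConv::ReorderKernel;
//   cutlass::Kernel<Reorder><<<Reorder::get_grid_shape(params),
//                              Reorder::get_block_shape(), 0, stream>>>(params);
//
//   // Main fused-epilogue convolution with runtime-sized shared memory.
//   cutlass::Kernel<DirectConv><<<grid, block, params.get_smem_size(), stream>>>(params);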
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
325
include/cutlass/conv/thread/depthwise_mma.h
Normal file
@ -0,0 +1,325 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Templates exposing architecture support for depthwise convolution
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/thread/mma.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace thread {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// MMA operation
|
||||
template <
|
||||
/// Size of the matrix product (concept: GemmShape)
|
||||
typename Shape_,
|
||||
/// Number of threads participating
|
||||
int kThreads_,
|
||||
/// Data type of A elements
|
||||
typename ElementA,
|
||||
/// Data type of B elements
|
||||
typename ElementB,
|
||||
/// Element type of C matrix
|
||||
typename ElementC,
|
||||
/// Inner product operator
|
||||
typename Operator
|
||||
>
|
||||
struct ElementwiseInnerProduct;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// General implementation
|
||||
template <
|
||||
/// Size of the matrix product (concept: GemmShape)
|
||||
typename Shape_,
|
||||
/// Data type of A elements
|
||||
typename ElementA_,
|
||||
/// Data type of B elements
|
||||
typename ElementB_,
|
||||
/// Element type of C matrix
|
||||
typename ElementC_>
|
||||
struct ElementwiseInnerProduct<Shape_, 1, ElementA_, ElementB_, ElementC_, arch::OpMultiplyAdd> {
|
||||
using Shape = Shape_;
|
||||
using Operator = arch::OpMultiplyAdd;
|
||||
using ElementC = ElementC_;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(Array<ElementC_, Shape::kN> &d,
|
||||
Array<ElementA_, Shape::kN> const &a,
|
||||
Array<ElementB_, Shape::kN> const &b,
|
||||
Array<ElementC_, Shape::kN> const &c) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < Shape::kN; ++i) {
|
||||
d[i] = a[i] * b[i] + c[i];
|
||||
}
|
||||
}
|
||||
};
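// Hedged usage sketch of the general implementation above: a 4-wide
// per-channel multiply-accumulate (shape and element types are illustrative).
//
//   cutlass::conv::thread::ElementwiseInnerProduct<
//       cutlass::gemm::GemmShape<1, 4, 1>, 1, float, float, float,
//       cutlass::arch::OpMultiplyAdd> op;
//   cutlass::Array<float, 4> d, a, b, c;
//   /* fill a, b, c ... */
//   op(d, a, b, c);  // d[i] = a[i] * b[i] + c[i] for each of the 4 channels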
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization of half_t
|
||||
template <>
|
||||
struct ElementwiseInnerProduct<
|
||||
gemm::GemmShape<2, 2, 1>,
|
||||
1,
|
||||
half_t,
|
||||
half_t,
|
||||
half_t,
|
||||
arch::OpMultiplyAdd> {
|
||||
|
||||
using Shape = gemm::GemmShape<2, 2, 1>;
|
||||
using Operator = arch::OpMultiplyAdd;
|
||||
using ElementC = half_t;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
Array<half_t, 2> &d,
|
||||
Array<half_t, 2> const &a,
|
||||
Array<half_t, 2> const &b,
|
||||
Array<half_t, 2> const &c
|
||||
) {
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))
|
||||
|
||||
__half2 const & A = reinterpret_cast<__half2 const &>(a);
|
||||
__half2 const & B = reinterpret_cast<__half2 const &>(b);
|
||||
__half2 const & C = reinterpret_cast<__half2 const &>(c);
|
||||
|
||||
__half2 tmp_D = __hfma2(A, B, C);
|
||||
|
||||
d = reinterpret_cast<Array<half_t, 2> const &>(tmp_D);
|
||||
|
||||
#else
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
d[i] = a[i] * b[i] + c[i];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
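// Note: on SM60 and newer the specialization above folds both lanes into a
// single __hfma2, so the 2-wide inner product costs one instruction; the
// scalar loop in the #else branch keeps older targets correct.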
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape,
|
||||
/// Data type of A elements
|
||||
typename ElementA,
|
||||
/// Data type of B elements
|
||||
typename ElementB,
|
||||
/// Element type of C matrix
|
||||
typename ElementC,
|
||||
/// Concept: arch::OpMultiplyAdd or arch::Mma<>
|
||||
typename Operator = arch::OpMultiplyAdd,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool
|
||||
>
|
||||
struct DepthwiseDirectConvElementwiseInnerProduct;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Template that handles all packed matrix layouts
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Data type of A elements
|
||||
typename ElementA_,
|
||||
/// Data type of B elements
|
||||
typename ElementB_,
|
||||
/// Element type of C matrix
|
||||
typename ElementC_,
|
||||
/// Operator used to compute GEMM
|
||||
typename Operator_
|
||||
>
|
||||
struct DepthwiseDirectConvElementwiseInnerProductGeneric {
|
||||
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
|
||||
/// Data type of operand A
|
||||
using ElementA = ElementA_;
|
||||
|
||||
/// Data type of operand B
|
||||
using ElementB = ElementB_;
|
||||
|
||||
/// Element type of operand C
|
||||
using ElementC = ElementC_;
|
||||
|
||||
/// Underlying mathematical operator
|
||||
using Operator = Operator_;
|
||||
|
||||
/// A operand storage
|
||||
using FragmentA = Array<ElementA, Shape::kMN>;
|
||||
|
||||
/// B operand storage
|
||||
using FragmentB = Array<ElementB, Shape::kN>;
|
||||
|
||||
/// C operand storage
|
||||
using FragmentC = Array<ElementC, Shape::kMN>;
|
||||
|
||||
/// Instruction
|
||||
using MmaOp = cutlass::conv::thread::ElementwiseInnerProduct<
|
||||
gemm::GemmShape<Shape::kN, Shape::kN, 1>,
|
||||
1,
|
||||
ElementA,
|
||||
ElementB,
|
||||
ElementC,
|
||||
Operator>;
|
||||
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Computes a matrix product D = A * B + C
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
FragmentC & D,
|
||||
FragmentA const & A,
|
||||
FragmentB const & B,
|
||||
FragmentC const & C) {
|
||||
Array<ElementC, Shape::kN> *ptr_D = reinterpret_cast<Array<ElementC, Shape::kN> *>(&D);
|
||||
Array<ElementA, Shape::kN> const *ptr_A =
|
||||
reinterpret_cast<Array<ElementA, Shape::kN> const *>(&A);
|
||||
Array<ElementB, Shape::kN> const *ptr_B =
|
||||
reinterpret_cast<Array<ElementB, Shape::kN> const *>(&B);
|
||||
|
||||
MmaOp mma_op;
|
||||
|
||||
// Copy accumulators
|
||||
D = C;
|
||||
|
||||
// Compute matrix product
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < Shape::kN / MmaOp::Shape::kN; ++n) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < Shape::kM; ++m) {
|
||||
|
||||
Array<ElementC, MmaOp::Shape::kN> tmpD = ptr_D[m * Shape::kN / MmaOp::Shape::kN + n];
|
||||
Array<ElementA, MmaOp::Shape::kN> tmpA = ptr_A[m * Shape::kN / MmaOp::Shape::kN + n];
|
||||
Array<ElementB, MmaOp::Shape::kN> tmpB = ptr_B[n];
|
||||
|
||||
mma_op(tmpD, tmpA, tmpB, tmpD);
|
||||
|
||||
ptr_D[m * Shape::kN / MmaOp::Shape::kN + n] = tmpD;
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
};
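// Hedged usage sketch of the generic fragment-level product above, with an
// illustrative Shape of GemmShape<4, 8, 1>: FragmentA and FragmentC hold
// 4 x 8 = 32 elements (output positions x groups) while FragmentB holds the
// 8 per-group filter values broadcast across all 4 output positions.
//
//   using InnerProduct =
//       cutlass::conv::thread::DepthwiseDirectConvElementwiseInnerProductGeneric<
//           cutlass::gemm::GemmShape<4, 8, 1>, float, float, float,
//           cutlass::arch::OpMultiplyAdd>;
//   InnerProduct op;
//   InnerProduct::FragmentA a;     // 32 activations
//   InnerProduct::FragmentB b;     // 8 per-group filter values
//   InnerProduct::FragmentC c, d;  // 32 accumulators each
//   op(d, a, b, c);                // d = a * b (broadcast over rows) + c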
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Data type of A elements
|
||||
typename ElementA_,
|
||||
/// Data type of B elements
|
||||
typename ElementB_,
|
||||
/// Element type of C matrix
|
||||
typename ElementC_
|
||||
>
|
||||
struct DepthwiseDirectConvElementwiseInnerProduct<
|
||||
Shape_,
|
||||
ElementA_,
|
||||
ElementB_,
|
||||
ElementC_,
|
||||
arch::OpMultiplyAdd
|
||||
> {
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
|
||||
/// Data type of operand A
|
||||
using ElementA = ElementA_;
|
||||
|
||||
/// Data type of operand B
|
||||
using ElementB = ElementB_;
|
||||
|
||||
/// Element type of operand C
|
||||
using ElementC = ElementC_;
|
||||
|
||||
/// Underlying mathematical operator
|
||||
using Operator = arch::OpMultiplyAdd;
|
||||
|
||||
/// A operand storage
|
||||
using FragmentA =
|
||||
Array<ElementA, Shape::kMN>; // output_tile_size per thread * groups_per_thread
|
||||
|
||||
/// B operand storage
|
||||
using FragmentB = Array<ElementB, Shape::kN>; // 1 * groups_per_thread
|
||||
|
||||
/// C operand storage
|
||||
using FragmentC =
|
||||
Array<ElementC, Shape::kMN>; // output_tile_size per thread * groups_per_thread
|
||||
|
||||
static bool const use_optimized = 0;
|
||||
|
||||
using ArchMmaOperator = DepthwiseDirectConvElementwiseInnerProductGeneric<Shape,
|
||||
ElementA,
|
||||
ElementB,
|
||||
ElementC,
|
||||
Operator>;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Computes a matrix product D = A * B + C
|
||||
CUTLASS_HOST_DEVICE
|
||||
void operator()(
|
||||
FragmentC & D,
|
||||
FragmentA const & A,
|
||||
FragmentB const & B,
|
||||
FragmentC const & C) {
|
||||
|
||||
ArchMmaOperator mma;
|
||||
|
||||
mma(D, A, B, C);
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace thread
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
@ -145,6 +145,7 @@ private:
|
||||
uint32_t predicates_[kAccessesPerVector];
|
||||
int filter_rs_;
|
||||
int filter_c_;
|
||||
int channels_per_group_;
|
||||
|
||||
//
|
||||
// Assertions
|
||||
@ -175,6 +176,7 @@ public:
|
||||
|
||||
filter_c_ = threadblock_offset.row() + thread_coord.contiguous();
|
||||
Index column = threadblock_offset.column() + thread_coord.strided();
|
||||
channels_per_group_ = problem_size_.C / problem_size_.groups;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
@ -188,7 +190,7 @@ public:
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C);
|
||||
clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= channels_per_group_);
|
||||
}
|
||||
|
||||
pointer_ += (
|
||||
@ -229,7 +231,7 @@ public:
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C);
|
||||
clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= channels_per_group_);
|
||||
}
|
||||
|
||||
pointer_ += next;
|
||||
|
||||
230
include/cutlass/conv/threadblock/depthwise_direct_conv_params.h
Normal file
@ -0,0 +1,230 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*!
  \file
  \brief Extracts the host-params objects into non-template code.
*/

#pragma once

#define TRACE_CONV_PARAMS_INITIALIZERS_ENABLED 0

#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv2d_problem_size.h"

#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED
#include <fstream>
#endif

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized
template <typename Layout_ = layout::TensorNHWC>
struct Depthwise2dFpropDirectConvParams;

/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation
template <typename Layout_ = layout::TensorNHWC>
struct Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams;

/// Parameters structure used for DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized
template <typename Layout_ = layout::TensorNHWC>
struct Depthwise2dFpropDirectConvFilterIteratorParams;

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized
template <>
struct Depthwise2dFpropDirectConvParams<layout::TensorNHWC> {

  using Layout = layout::TensorNHWC;

  Layout layout;

  int32_t activation_tile_h;
  int32_t activation_tile_w;
  int32_t activation_tile_hw;
  FastDivmod activation_tile_w_divmod;

  int filter[2];
  int stride[2];
  int dilation[2];
  int inc_next[2];
  FastDivmod pq_divmod;
  FastDivmod q_divmod;

  int activation_load_count;
  int activation_storage_elements;
  int activation_size;

  //
  // Methods
  //

  CUTLASS_HOST_DEVICE
  Depthwise2dFpropDirectConvParams() { }

  CUTLASS_HOST_DEVICE
  Depthwise2dFpropDirectConvParams(
    Conv2dProblemSize const &problem_size,
    Layout const &layout,                          ///< layout object
    MatrixCoord threadblock_shape,                 ///< CTA threadblock shape
    Layout::TensorCoord threadblock_output_shape,  ///< output tile shape per threadblock
    const int element_size_bits,                   ///< bits per activation element
    const int thread_count,                        ///< threads per threadblock
    const int thread_count_contiguous,             ///< threads along the contiguous dimension
    const int element_per_load)                    ///< elements per load
    : layout(layout) {

    filter[0] = problem_size.S;
    filter[1] = problem_size.R;

    stride[0] = problem_size.stride_w;
    stride[1] = problem_size.stride_h;

    dilation[0] = problem_size.dilation_w;
    dilation[1] = problem_size.dilation_h;

    // Compute the activation tile size per threadblock here, because stride and dilation are
    // runtime parameters.
    activation_tile_h = (threadblock_output_shape.h() - 1) * problem_size.stride_h +
                        (problem_size.R - 1) * problem_size.dilation_h + 1;
    activation_tile_w = (threadblock_output_shape.w() - 1) * problem_size.stride_w +
                        (problem_size.S - 1) * problem_size.dilation_w + 1;
    activation_tile_hw = activation_tile_h * activation_tile_w;

    activation_tile_w_divmod = FastDivmod(activation_tile_w);

    // The following two values cannot be templatized, because stride and dilation are
    // runtime parameters.
    activation_load_count = (thread_count_contiguous * activation_tile_hw + (thread_count - 1)) / thread_count;
    activation_storage_elements = activation_load_count * element_per_load * thread_count;
    activation_size = activation_storage_elements * element_size_bits / 8;

    // FastDivmod for output P, Q
    int tiles_p =
        (problem_size.P + (threadblock_output_shape.h() - 1)) / (threadblock_output_shape.h());
    int tiles_q =
        (problem_size.Q + (threadblock_output_shape.w() - 1)) / (threadblock_output_shape.w());

    pq_divmod = FastDivmod(tiles_p * tiles_q);
    q_divmod = FastDivmod(tiles_q);

    // next S
    inc_next[0] = problem_size.dilation_w;
    // next R
    inc_next[1] = (activation_tile_w * problem_size.dilation_h - (problem_size.S - 1) * problem_size.dilation_w);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation
template <>
struct Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams<layout::TensorNHWC> {
  using Layout = layout::TensorNHWC;

  Layout layout;

  FastDivmod pq_divmod;
  FastDivmod q_divmod;

  int activation_size;

  //
  // Methods
  //

  CUTLASS_HOST_DEVICE
  Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams() {}

  CUTLASS_HOST_DEVICE
  Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams(
    Conv2dProblemSize const &problem_size,
    Layout const &layout,                          ///< layout object
    MatrixCoord threadblock_shape,                 ///< threadblock shape
    Layout::TensorCoord threadblock_output_shape,  ///< output tile shape per threadblock
    const int activation_size_                     ///< activation size loaded by the iterator
  )
    : layout(layout),
      activation_size(activation_size_) {
    // FastDivmod for output P, Q
    int tiles_p =
        (problem_size.P + (threadblock_output_shape.h() - 1)) / (threadblock_output_shape.h());
    int tiles_q =
        (problem_size.Q + (threadblock_output_shape.w() - 1)) / (threadblock_output_shape.w());

    pq_divmod = FastDivmod(tiles_p * tiles_q);
    q_divmod = FastDivmod(tiles_q);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Parameters structure used for DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized
template <>
struct Depthwise2dFpropDirectConvFilterIteratorParams<layout::TensorNHWC> {
  using Layout = layout::TensorNHWC;

  Layout layout;

  int filter_size;

  bool is_convolution;

  //
  // Methods
  //

  CUTLASS_HOST_DEVICE
  Depthwise2dFpropDirectConvFilterIteratorParams() {}

  CUTLASS_HOST_DEVICE
  Depthwise2dFpropDirectConvFilterIteratorParams(
    Conv2dProblemSize const &problem_size,
    Layout const &layout,           ///< layout object
    MatrixCoord threadblock_shape,  ///< threadblock shape
    const int filter_size_)         ///< filter size loaded by the iterator
    : layout(layout),
      filter_size(filter_size_),
      is_convolution(problem_size.mode == Mode::kConvolution) {}
};

} // namespace threadblock
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
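For intuition about the activation tile computation in the params constructor above, a small standalone example with assumed values (an 8x8 output tile per threadblock, 3x3 filter, stride 2, dilation 1):

#include <cstdio>

int main() {
  int out_h = 8, out_w = 8;        // threadblock output tile (assumed)
  int R = 3, S = 3;                // filter extent (assumed)
  int stride_h = 2, stride_w = 2;
  int dilation_h = 1, dilation_w = 1;

  // Same formula as Depthwise2dFpropDirectConvParams above:
  int tile_h = (out_h - 1) * stride_h + (R - 1) * dilation_h + 1;  // = 17
  int tile_w = (out_w - 1) * stride_w + (S - 1) * dilation_w + 1;  // = 17
  std::printf("activation tile: %d x %d = %d elements per channel\n",
              tile_h, tile_w, tile_h * tile_w);                    // 17 x 17 = 289
  return 0;
}

Because stride and dilation are runtime values in this variant, this tile extent (and hence the shared-memory footprint) can only be computed at kernel launch, which is exactly why these fields live in a non-template params object.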
@ -0,0 +1,314 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile)
    matrix from memory.

    This iterator assumes TensorNHWC layout of tensors in Global Memory.
*/

#pragma once

#include "cutlass/array.h"
#include "cutlass/conv/conv2d_problem_size.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/threadblock/depthwise_direct_conv_params.h"
#include "cutlass/coord.h"
#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Shape_,
          typename OutputTileShape_,
          typename StrideShape_,
          typename DilationShape_,
          typename ActivationShape_,
          typename Element_,
          typename Layout_,
          typename ThreadMap_,
          typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess> >
class DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation {
 public:
  //
  // Types
  //

  using Shape = Shape_;
  using OutputTileShape = OutputTileShape_;
  using Element = Element_;
  using Layout = Layout_;
  using TensorCoord = typename Layout::TensorCoord;
  using ThreadMap = ThreadMap_;
  using AccessType = AccessType_;
  using TensorRef = cutlass::TensorRef<Element, Layout>;
  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;
  static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized;
  static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
  static int const kConvDim = 2;
  using ConvProblemSize = typename conv::Conv2dProblemSize;

  // Compile-time values of stride, dilation, and activation shape
  using StrideShape = StrideShape_;
  using DilationShape = DilationShape_;
  using ActivationShape = ActivationShape_;

  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
  static int const kActivationSize = ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess *
                                     ThreadMap::kThreads * sizeof_bits<Element>::value / 8;

  static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
                "Vectors implied by the thread map must be divisible by the access type.");

  //
  // Simplifying assertions
  //
  static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1");

  static_assert(OutputTileShape::kN == 1, "Require OutputTileShape::kN == 1");
  static_assert(OutputTileShape::kC == Shape::kColumn, "Require OutputTile shape == channels per threadblock");

  //
  // Parameters structure
  //

  using Params = Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams<Layout>;

 private:
  Conv2dProblemSize const &problem_size_;
  Params const &params_;
  char const *pointer_;

  // Base channel for the current threadblock
  int base_c_;
  // Base activation index for the current threadblock
  int offset_initial_npq_;
  // Base activation coordinate for the current threadblock
  TensorCoord activation_base_;
  // Initial thread position
  int offset_initial_hwc_;
  // Total number of load instructions per thread
  int iterator_load_;
  // Thread loading position
  int iterator_hwc_;
  // Whether the activation batch index N is inside the tensor
  bool valid_n_;

 public:

  CUTLASS_HOST_DEVICE
  DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation(
    Params const &params,
    Conv2dProblemSize const &problem_size,
    Element const *ptr,
    int thread_idx,
    MatrixCoord const &threadblock_offset = MatrixCoord()
  )
    : params_(params),
      problem_size_(problem_size),
      pointer_(reinterpret_cast<char const *>(ptr)),
      offset_initial_npq_(threadblock_offset.row()),
      offset_initial_hwc_(thread_idx),
      iterator_load_(0) {

    base_c_ = threadblock_offset.column();

    set_iteration_index(0);

    set_activation_coord(offset_initial_npq_);
  }

  CUTLASS_HOST_DEVICE
  void set_activation_coord(int offset_npq) {
    int offset_initial_n, offset_initial_p, offset_initial_q;
    int residual;

    params_.pq_divmod(offset_initial_n, residual, offset_npq);
    params_.q_divmod(offset_initial_p, offset_initial_q, residual);

    int base_n = offset_initial_n;

    int base_h =
        offset_initial_p * OutputTileShape::kH * StrideShape::kRow - problem_size_.pad_h;

    int base_w =
        offset_initial_q * OutputTileShape::kW * StrideShape::kColumn - problem_size_.pad_w;

    activation_base_ = TensorCoord(base_n, base_h, base_w, base_c_);

    valid_n_ = activation_base_.n() < problem_size_.N;
  }

  CUTLASS_HOST_DEVICE
  static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) {
    return Params(
        problem_size,
        layout,
        {Shape::kRow, Shape::kColumn},
        {OutputTileShape::kN, OutputTileShape::kH, OutputTileShape::kW, OutputTileShape::kC},
        kActivationSize);
  }

  /// Overrides the internal iteration index
  CUTLASS_HOST_DEVICE
  void set_iteration_index(Index index) {
    iterator_hwc_ = offset_initial_hwc_ + index * ThreadMap::kThreads;
    iterator_load_ = index;
  }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
  }

  CUTLASS_HOST_DEVICE
  void advance() {
    // Go to the next threadblock
    offset_initial_npq_ += problem_size_.split_k_slices;

    set_iteration_index(0);

    set_activation_coord(offset_initial_npq_);
  }

  /// Returns the coordinate in the activations tensor X that is currently pointed to
  /// by the iterator.
  CUTLASS_HOST_DEVICE
  TensorCoord at() const {
    int c = iterator_hwc_ % ThreadMap::Detail::ShapeVec::kContiguous;
    int next = iterator_hwc_ / ThreadMap::Detail::ShapeVec::kContiguous;
    int h = next / ActivationShape::kW;
    int w = next % ActivationShape::kW;

    c = c * AccessType::kElements;

    return activation_base_ + TensorCoord(0, h, w, c);
  }

  /// Returns true if the current coordinate is within the activations tensor X
  CUTLASS_HOST_DEVICE
  bool valid() const {
    TensorCoord coord = at();
    bool valid_c = coord.c() < problem_size_.C;
    bool valid_h = coord.h() >= 0 && coord.h() < problem_size_.H;
    bool valid_w = coord.w() >= 0 && coord.w() < problem_size_.W;
    return valid_n_ ? (valid_c && valid_h && valid_w) : false;
  }

  /// Returns a pointer to the vector starting at the current coordinate
  CUTLASS_HOST_DEVICE
  AccessType const *get() const {
    TensorCoord coord = at();
    LongIndex offset = params_.layout(coord);

    AccessType const *ptr =
        reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);

    return ptr;
  }

  /// Increments to the next memory access
  CUTLASS_HOST_DEVICE
  DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation &operator++() {

    ++iterator_load_;
    iterator_hwc_ += ThreadMap::kThreads;

    if (iterator_load_ < ThreadMap::Iterations::kCount) {
      return *this;
    }

    iterator_load_ = 0;
    iterator_hwc_ = offset_initial_hwc_;

    return *this;
  }

  /// Determines the activation size loaded by the iterator
  CUTLASS_HOST_DEVICE
  int get_load_size() {
    return kActivationSize;
  }

  /// Determines the number of iterations needed
  CUTLASS_HOST_DEVICE
  int get_iteration_num() {
    return ThreadMap::Iterations::kCount;
  }

  /// Determines whether the depthwise fprop can execute the given problem.
  CUTLASS_HOST_DEVICE
  static Status can_implement(Conv2dProblemSize const &problem_size) {

    // Check the stride and dilation constraints
    if (problem_size.stride_h != StrideShape::kRow || problem_size.stride_w != StrideShape::kColumn) {
      return Status::kErrorInvalidProblem;
    }

    if (problem_size.dilation_h != DilationShape::kRow || problem_size.dilation_w != DilationShape::kColumn) {
      return Status::kErrorInvalidProblem;
    }

    // Check the alignment constraint on the iterator's contiguous dimension
    if (problem_size.C % AccessType::kElements) {
      return Status::kErrorInvalidProblem;
    }

    return Status::kSuccess;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
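The pq_divmod/q_divmod decomposition used by set_activation_coord above can be sketched with plain integer arithmetic (cutlass::FastDivmod only accelerates these operations; the snippet below is a model, not the library API):

// Plain-integer sketch of the npq decomposition. tiles_p and tiles_q are the
// number of output tiles per image, as computed in the params structures.
void decompose_npq(int npq, int tiles_p, int tiles_q,
                   int &n, int &tile_p, int &tile_q) {
  int residual = npq % (tiles_p * tiles_q);  // remainder from pq_divmod
  n = npq / (tiles_p * tiles_q);             // quotient from pq_divmod
  tile_p = residual / tiles_q;               // quotient from q_divmod
  tile_q = residual % tiles_q;               // remainder from q_divmod
}

From (tile_p, tile_q), the base activation coordinate follows by scaling with the output tile extent and stride, then subtracting the padding, exactly as the member function does.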
@ -0,0 +1,291 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile)
    matrix from memory.

    This iterator assumes TensorNHWC layout of tensors in Global Memory.
*/

#pragma once

#include "cutlass/array.h"
#include "cutlass/conv/conv2d_problem_size.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/threadblock/depthwise_direct_conv_params.h"
#include "cutlass/coord.h"
#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Shape_,
          typename OutputTileShape_,
          typename Element_,
          typename Layout_,
          typename ThreadMap_,
          typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess> >
class DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized {
 public:
  //
  // Types
  //

  using Shape = Shape_;
  using OutputTileShape = OutputTileShape_;
  using Element = Element_;
  using Layout = Layout_;
  using TensorCoord = typename Layout::TensorCoord;
  using ThreadMap = ThreadMap_;
  using AccessType = AccessType_;
  using TensorRef = cutlass::TensorRef<Element, Layout>;
  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;
  static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized;
  static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
  static int const kConvDim = 2;
  using ConvProblemSize = typename conv::Conv2dProblemSize;

  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;

  static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
                "Vectors implied by the thread map must be divisible by the access type.");

  //
  // Simplifying assertions
  //
  static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1");

  static_assert(OutputTileShape::kN == 1, "Require OutputTileShape::kN == 1");
  static_assert(OutputTileShape::kC == Shape::kColumn, "Require OutputTile shape == channels per threadblock");

  //
  // Parameters structure
  //

  using Params = Depthwise2dFpropDirectConvParams<Layout>;

 private:
  Conv2dProblemSize const &problem_size_;
  Params const &params_;
  char const *pointer_;

  // Base channel for the current threadblock
  int base_c_;
  // Base activation index for the current threadblock
  int offset_initial_npq_;
  // Base activation coordinate for the current threadblock
  TensorCoord activation_base_;
  // Initial thread position
  int offset_initial_hwc_;
  // Total number of load instructions per thread
  int iterator_load_;
  // Thread loading position
  int iterator_hwc_;
  // Number of loads for the activations tensor X
  const int number_of_loads_;

 public:

  CUTLASS_HOST_DEVICE
  DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized(
    Params const &params,
    Conv2dProblemSize const &problem_size,
    Element const *ptr,
    int thread_idx,
    MatrixCoord const &threadblock_offset = MatrixCoord()
  )
    : params_(params),
      problem_size_(problem_size),
      pointer_(reinterpret_cast<char const *>(ptr)),
      offset_initial_npq_(threadblock_offset.row()),
      offset_initial_hwc_(thread_idx),
      iterator_load_(0),
      number_of_loads_(params.activation_load_count) {

    base_c_ = threadblock_offset.column();

    set_activation_coord(offset_initial_npq_);

    set_iteration_index(0);
  }

  CUTLASS_HOST_DEVICE
  void set_activation_coord(int offset_npq) {
    int offset_initial_n, offset_initial_p, offset_initial_q;
    int residual;

    params_.pq_divmod(offset_initial_n, residual, offset_npq);
    params_.q_divmod(offset_initial_p, offset_initial_q, residual);

    int base_n = offset_initial_n;

    int base_h =
        offset_initial_p * OutputTileShape::kH * problem_size_.stride_h - problem_size_.pad_h;

    int base_w =
        offset_initial_q * OutputTileShape::kW * problem_size_.stride_w - problem_size_.pad_w;

    activation_base_ = TensorCoord(base_n, base_h, base_w, base_c_);
  }

  CUTLASS_HOST_DEVICE
  static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) {
    return Params(
        problem_size,
        layout,
        {Shape::kRow, Shape::kColumn},
        {OutputTileShape::kN, OutputTileShape::kH, OutputTileShape::kW, OutputTileShape::kC},
        sizeof_bits<Element>::value,
        ThreadMap::kThreads,
        ThreadMap::Detail::ShapeVec::kContiguous,
        ThreadMap::kElementsPerAccess);
  }

  /// Overrides the internal iteration index
  CUTLASS_HOST_DEVICE
  void set_iteration_index(Index index) {
    iterator_hwc_ = offset_initial_hwc_ + index * ThreadMap::kThreads;
    iterator_load_ = index;
  }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
  }

  CUTLASS_HOST_DEVICE
  void advance() {
    // Go to the next threadblock
    offset_initial_npq_ += problem_size_.split_k_slices;

    set_activation_coord(offset_initial_npq_);
  }

  /// Returns the coordinate in the activations tensor X that is currently pointed to
  /// by the iterator.
  CUTLASS_HOST_DEVICE
  TensorCoord at() const {

    int c = iterator_hwc_ % ThreadMap::Detail::ShapeVec::kContiguous;
    int next = iterator_hwc_ / ThreadMap::Detail::ShapeVec::kContiguous;
    int h, w;
    params_.activation_tile_w_divmod(h, w, next);

    c = c * AccessType::kElements;

    return activation_base_ + TensorCoord(0, h, w, c);
  }

  /// Returns true if the current coordinate is within the activations tensor X
  CUTLASS_HOST_DEVICE
  bool valid() const {
    TensorCoord coord = at();

    return coord.n() < problem_size_.N && coord.h() >= 0 && coord.h() < problem_size_.H &&
           coord.w() >= 0 && coord.w() < problem_size_.W && coord.c() < problem_size_.C;
  }

  /// Returns a pointer to the vector starting at the current coordinate
  CUTLASS_HOST_DEVICE
  AccessType const *get() const {
    TensorCoord coord = at();
    LongIndex offset = params_.layout(coord);

    AccessType const *ptr =
        reinterpret_cast<AccessType const *>(pointer_ + offset * sizeof_bits<Element>::value / 8);

    return ptr;
  }

  /// Increments to the next memory access
  CUTLASS_HOST_DEVICE
  DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized &operator++() {

    ++iterator_load_;
    iterator_hwc_ += ThreadMap::kThreads;

    if (iterator_load_ < number_of_loads_) {
      return *this;
    }

    iterator_load_ = 0;
    iterator_hwc_ = offset_initial_hwc_;

    return *this;
  }

  /// Determines the activation size loaded by the iterator
  CUTLASS_HOST_DEVICE
  int get_load_size() {
    return params_.activation_size;
  }

  /// Determines the number of iterations needed
  CUTLASS_HOST_DEVICE
  int get_iteration_num() {
    return number_of_loads_;
  }

  /// Determines whether the depthwise fprop can execute the given problem.
  CUTLASS_HOST_DEVICE
  static Status can_implement(Conv2dProblemSize const &problem_size) {
    // Check the alignment constraint on the iterator's contiguous dimension
    if (problem_size.C % AccessType::kElements) {
      return Status::kErrorInvalidProblem;
    }

    return Status::kSuccess;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
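Both activation iterators expose a valid()/get() pair that is consumed by cp_async_zfill in the multistage mainloop below; conceptually, an out-of-bounds access is not branched around but zero-filled. A simplified device-side model of that contract (this is an illustration, not the CUTLASS API):

// Simplified model of a guarded vector load with zero-fill semantics,
// assuming AccessType is an aggregate such as cutlass::AlignedArray.
template <typename AccessType>
__device__ void guarded_load(AccessType &dst, AccessType const *src, bool valid) {
  if (valid) {
    dst = *src;          // vector load from global memory
  } else {
    dst = AccessType{};  // zero-fill, mirroring cp.async zfill behavior
  }
}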
@ -0,0 +1,551 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel.
*/

#pragma once

#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/cache_operation.h"
#include "cutlass/conv/threadblock/depthwise_mma_base.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Cache operation for operand A
    cutlass::arch::CacheOperation::Kind CacheOpA,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Cache operation for operand B
    cutlass::arch::CacheOperation::Kind CacheOpB,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Number of stages
    int Stages,
    /// Epilogue that stores the data into global memory
    typename Epilogue_,
    /// Iterator implementation variant
    conv::IteratorAlgorithm IteratorAlgorithm_ = conv::IteratorAlgorithm::kOptimized,
    /// Used for partial specialization
    typename Enable = bool>
class DepthwiseFpropDirectConvMultipleStage :
  public DepthwiseDirectConvMmaBase<Shape_, Policy_, Stages> {
 public:
  ///< Base class
  using Base = DepthwiseDirectConvMmaBase<Shape_, Policy_, Stages>;
  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape = Shape_;
  ///< Iterates over tiles of A operand in global memory
  using IteratorA = IteratorA_;
  ///< Iterates over tiles of B operand in global memory
  using IteratorB = IteratorB_;
  ///< Policy describing tuning details
  using Policy = Policy_;

  using Epilogue = Epilogue_;

  using SmemIteratorA = SmemIteratorA_;
  using SmemIteratorB = SmemIteratorB_;

  static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
  static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;

  static conv::IteratorAlgorithm const kIteratorAlgorithm = IteratorAlgorithm_;

  //
  // Dependent types
  //

  /// Fragment of accumulator tile
  using ElementC = typename Policy::Operator::ElementC;
  using FragmentC = typename Policy::Operator::FragmentC;

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  /// Internal structure exposed for introspection.
  struct Detail {

    /// Number of cp.async instructions to load one stage of operand A
    static int const AsyncCopyIterationsPerStageA =
        IteratorA::ThreadMap::Iterations::kCount;

    /// Number of cp.async instructions to load one stage of operand B
    static int const AsyncCopyIterationsPerStageB =
        IteratorB::ThreadMap::Iterations::kCount;

    /// Number of stages
    static int const kStages = Stages;

    /// Number of cp.async instructions to load one group of operand B
    static int const kAccessesPerGroupB =
        (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
  };

 private:

  using WarpLoadedFragmentA = typename Operator::FragmentA;
  using WarpLoadedFragmentB = typename Operator::FragmentB;
  using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
  using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;

 private:

  //
  // Data members
  //

  /// Iterator to write threadblock-scoped tile of A operand to shared memory
  SmemIteratorA smem_iterator_A_;

  /// Iterator to write threadblock-scoped tile of B operand to shared memory
  SmemIteratorB smem_iterator_B_;

 public:

  /// Construct from tensor references
  CUTLASS_DEVICE
  DepthwiseFpropDirectConvMultipleStage(
      ///< Shared storage needed for internal use by threadblock-scoped GEMM
      typename Base::SharedStorage &shared_storage,
      ///< ID within the threadblock
      int thread_idx,
      ///< ID of warp
      int warp_idx,
      ///< ID of each thread within a warp
      int lane_idx
  ):
    Base(shared_storage, thread_idx, warp_idx, lane_idx),
    smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
    smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
  {
    // Compute warp location within threadblock tile by mapping the warp_id to
    // three coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension

    int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
    int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

    int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
    int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

    // Add per-warp offsets in units of warp-level tiles
    this->warp_tile_iterator_A_.add_tile_offset(
        {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
    this->warp_tile_iterator_B_.add_tile_offset(
        {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
  }

  CUTLASS_DEVICE
  void copy_tiles_and_advance(IteratorA &iterator_A,
                              IteratorB &iterator_B,
                              int group_start_A = 0,
                              int group_start_B = 0) {
    if (kIteratorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation) {
      // The number of iterations is a compile-time constant.
      iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
      this->smem_iterator_A_.set_iteration_index(group_start_A);

      // Async copy for operand A
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
        typename IteratorA::AccessType *dst_ptr =
            reinterpret_cast<typename IteratorA::AccessType *>(this->smem_iterator_A_.get());

        int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
                              IteratorA::ThreadMap::kElementsPerAccess /
                              IteratorA::kAccessesPerVector / 8;

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
          cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
              dst_ptr + v, iterator_A.get(), iterator_A.valid());

          ++iterator_A;
        }
        ++this->smem_iterator_A_;
      }
    } else {
      // The number of iterations is a runtime value.
      iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
      this->smem_iterator_A_.set_iteration_index(group_start_A);

      // Async copy for operand A
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < iterator_A.get_iteration_num(); ++j) {
        typename IteratorA::AccessType *dst_ptr =
            reinterpret_cast<typename IteratorA::AccessType *>(this->smem_iterator_A_.get());

        int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
                              IteratorA::ThreadMap::kElementsPerAccess /
                              IteratorA::kAccessesPerVector / 8;

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
          cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
              dst_ptr + v, iterator_A.get(), iterator_A.valid());

          ++iterator_A;
        }
        ++this->smem_iterator_A_;
      }
    }
  }

  /// Perform a threadblock-scoped matrix multiply-accumulate
  CUTLASS_DEVICE
  void operator()(
      ///< problem size of GEMM
      int gemm_k_iterations,
      ///< destination accumulator tile
      FragmentC &accum,
      ///< iterator over A operand in global memory
      IteratorA &iterator_A,
      ///< params of the global memory iterator for A
      typename IteratorA::Params const &iterator_a_params,
      ///< iterator over B operand in global memory
      IteratorB &iterator_B,
      ///< params of the global memory iterator for B
      typename IteratorB::Params const &iterator_b_params,
      ///< initial value of accumulator
      FragmentC const &src_accum,
      ///< epilogue
      Epilogue &epilogue,
      ///< output operator
      typename Epilogue::OutputOp const &output_op,
      ///< tile iterator for the destination
      typename Epilogue::OutputTileIterator &destination_iterator,
      ///< tile iterator for the source accumulator
      typename Epilogue::OutputTileIterator &source_iterator,

      int split_k_slices = 1
  ) {

    //
    // Prologue
    //

    // Issue several complete stages
    CUTLASS_PRAGMA_UNROLL
    for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) {

      if (stage == 0) {
        iterator_B.set_iteration_index(0);
        this->smem_iterator_B_.set_iteration_index(0);

        // Async copy for operand B
        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
          typename IteratorB::AccessType *dst_ptr =
              reinterpret_cast<typename IteratorB::AccessType *>(this->smem_iterator_B_.get());

          CUTLASS_PRAGMA_UNROLL
          for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
            int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
                                  IteratorB::ThreadMap::kElementsPerAccess /
                                  IteratorB::kAccessesPerVector / 8;

            cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
                dst_ptr + v, iterator_B.get(), iterator_B.valid());

            ++iterator_B;
          }

          ++this->smem_iterator_B_;
        }
      }

      if (kIteratorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation) {
        // The number of iterations is a compile-time constant.
        iterator_A.set_iteration_index(0);
        this->smem_iterator_A_.set_iteration_index(0);

        // Async copy for operand A
        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
          typename IteratorA::AccessType *dst_ptr =
              reinterpret_cast<typename IteratorA::AccessType *>(this->smem_iterator_A_.get());

          CUTLASS_PRAGMA_UNROLL
          for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
            int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
                                  IteratorA::ThreadMap::kElementsPerAccess /
                                  IteratorA::kAccessesPerVector / 8;

            cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
                dst_ptr + v, iterator_A.get(), iterator_A.valid());

            ++iterator_A;
          }

          ++this->smem_iterator_A_;
        }

      } else {
        // The number of iterations is a runtime value.
        iterator_A.set_iteration_index(0);
        this->smem_iterator_A_.set_iteration_num(iterator_A.get_iteration_num());
        this->smem_iterator_A_.set_iteration_index(0);

        // Async copy for operand A
        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < iterator_A.get_iteration_num(); ++j) {
          typename IteratorA::AccessType *dst_ptr =
              reinterpret_cast<typename IteratorA::AccessType *>(this->smem_iterator_A_.get());

          CUTLASS_PRAGMA_UNROLL
          for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
            int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
                                  IteratorA::ThreadMap::kElementsPerAccess /
                                  IteratorA::kAccessesPerVector / 8;

            cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
                dst_ptr + v, iterator_A.get(), iterator_A.valid());

            ++iterator_A;
          }

          ++this->smem_iterator_A_;
        }
      }

      // Move to the next stage
      iterator_A.advance();

      this->smem_iterator_A_.add_tile_offset({1, 0});

      // Insert a fence to group cp.async instructions into stages.
      cutlass::arch::cp_async_fence();
    }

    /////////////////////////////////////////////////////////////////////////////
    // Wait until kStages-2 stages have committed.
    cutlass::arch::cp_async_wait<Base::kStages - 2>();
    __syncthreads();

    // Pair of fragments used to overlap shared memory loads and math
    // instructions
    WarpLoadedFragmentA warp_loaded_frag_A[2];
    WarpLoadedFragmentB warp_loaded_frag_B[2];
    WarpTransformedFragmentA warp_transformed_frag_A[2];
    WarpTransformedFragmentB warp_transformed_frag_B[2];

    Operator warp_mma;

    this->warp_tile_iterator_A_.set_kgroup_index(0);
    this->warp_tile_iterator_B_.set_kgroup_index(0);

    this->warp_tile_iterator_A_.setup_initial_status(iterator_a_params);

    this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]);
    this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]);

    ++this->warp_tile_iterator_A_;
    ++this->warp_tile_iterator_B_;

    int smem_write_stage_idx = Base::kStages - 1;
    int smem_read_stage_idx = 0;

    warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0],
                       warp_loaded_frag_A[0], warp_loaded_frag_B[0]);

    //
    // Mainloop
    //

    unsigned int iterations = 0;
    constexpr int inner_loop_iterations = round_up(Base::kWarpGemmIterations, 2);

    CUTLASS_GEMM_LOOP
    for (; gemm_k_iterations > (-Base::kStages + 1);) {  // Each iteration is a CTA tile.

      accum.clear();

      //
      // Loop over GEMM K dimension
      //

      // Computes a warp-level GEMM on data held in shared memory
      // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate

      CUTLASS_PRAGMA_UNROLL
      for (int warp_mma_k = 0; warp_mma_k < inner_loop_iterations; ++warp_mma_k) {
        if (Base::kWarpGemmIterations % 2 == 0 || warp_mma_k + 1 != Base::kWarpGemmIterations) {
          // Load warp-level tiles from shared memory, wrapping to the k offset if
          // this is the last group, as the case may be.

          this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Shape::kK);
          this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Shape::kK);

          this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]);
          this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]);

          ++this->warp_tile_iterator_A_;
          ++this->warp_tile_iterator_B_;
        }

        if (warp_mma_k > 0)
          warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2],
                             warp_transformed_frag_B[warp_mma_k % 2],
                             warp_loaded_frag_A[warp_mma_k % 2],
                             warp_loaded_frag_B[warp_mma_k % 2]);

        // Issue global->shared copies for the next stage
        int group_start_iteration_A, group_start_iteration_B;

        if (warp_mma_k == 0) {
          group_start_iteration_A = 0;
          group_start_iteration_B = 0;
          copy_tiles_and_advance(
              iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
        }

        if (warp_mma_k < Base::kWarpGemmIterations) {
          warp_mma(
            accum,
            warp_transformed_frag_A[warp_mma_k % 2],
            warp_transformed_frag_B[warp_mma_k % 2],
            accum
          );
        }

        if (warp_mma_k + 1 == inner_loop_iterations)
          warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
                             warp_transformed_frag_B[(warp_mma_k + 1) % 2],
                             warp_loaded_frag_A[(warp_mma_k + 1) % 2],
                             warp_loaded_frag_B[(warp_mma_k + 1) % 2]);

        if (warp_mma_k + 2 == inner_loop_iterations) {
          // Insert a fence to group cp.async instructions into stages.
          cutlass::arch::cp_async_fence();

          // Wait until kStages-2 stages of cp.async have committed
          arch::cp_async_wait<Base::kStages - 2>();
          __syncthreads();

          // Move to the next CTA
          iterator_A.advance();

          this->smem_iterator_A_.add_tile_offset({1, 0});

          // Add negative offsets to return iterators to the 'start' of the
          // circular buffer in shared memory
          if (smem_write_stage_idx == (Base::kStages - 1)) {
            this->smem_iterator_A_.add_tile_offset({-Base::kStages, 0});

            smem_write_stage_idx = 0;
          } else {
            ++smem_write_stage_idx;
          }

          if (smem_read_stage_idx == (Base::kStages - 1)) {
            this->warp_tile_iterator_A_.advance(-(Base::kStages - 1) * iterator_A.get_load_size());
            smem_read_stage_idx = 0;
          } else {
            this->warp_tile_iterator_A_.advance(iterator_A.get_load_size());
            ++smem_read_stage_idx;
          }

          if (kIteratorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation) {
            this->warp_tile_iterator_A_.setup_initial_status(iterator_a_params);
          }

          // Go back to the start position; B has no multiple stages.
          this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Shape::kK, 0});

          --gemm_k_iterations;
        }
      }

      //
      // Epilogue
      //
      int32_t smem_base_offset = iterator_B.get_load_size() + (iterations % Base::kStages) * iterator_A.get_load_size();

      destination_iterator.set_tile_index(iterations * split_k_slices);

      source_iterator.set_tile_index(iterations * split_k_slices);

      epilogue(output_op, destination_iterator, accum, source_iterator, smem_base_offset);

      ++iterations;
    }

    // Insert fence and wait for all outstanding cp.async operations to commit.
    cutlass::arch::cp_async_fence();
    cutlass::arch::cp_async_wait<0>();
    __syncthreads();
  }

};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
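The control flow of the multistage mainloop above follows the usual CUTLASS software-pipelining pattern. A schematic host-style sketch (the real device calls are reduced to comments; kStages = 3 is an assumed value, not dictated by the kernel):

// Schematic model of the multistage pipeline: kStages-1 tiles are prefetched
// with async copies, then each mainloop iteration overlaps math on one stage
// with the asynchronous copy of another, using a circular smem buffer.
constexpr int kStages = 3;  // assumed stage count, for illustration only

void pipeline_sketch(int tiles) {
  // Prologue: fill kStages - 1 stages of the shared-memory circular buffer.
  for (int stage = 0; stage < kStages - 1; ++stage) {
    // issue_async_copy(stage);   // cp.async into the stage'th smem buffer
    // commit_stage();            // cp_async_fence()
  }
  // wait_until_pending(kStages - 2);  // cp_async_wait<kStages - 2>() + __syncthreads()

  int write_stage = kStages - 1, read_stage = 0;
  for (int tile = 0; tile < tiles; ++tile) {
    // compute_on(read_stage);        // warp-level MMAs on resident data
    // issue_async_copy(write_stage); // prefetch the next tile
    // commit_stage();
    // wait_until_pending(kStages - 2);
    write_stage = (write_stage + 1) % kStages;  // circular-buffer advance
    read_stage  = (read_stage + 1) % kStages;
  }
}

One design note visible in the kernel above: only operand A (activations) is multistaged; the depthwise filter tile B is loaded once in stage 0 and reused, which is why the B warp tile iterator is rewound to its start position at the end of each CTA tile.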
@ -0,0 +1,261 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile)
    matrix from memory.

    This iterator assumes TensorNHWC layout of tensors in Global Memory.
*/

#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/predicate_vector.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/conv/conv2d_problem_size.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_params.h"
|
||||
#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
namespace threadblock {
|
||||
|
||||
template <typename Shape_,
|
||||
typename Element_,
|
||||
typename Layout_,
|
||||
typename ThreadMap_,
|
||||
typename AccessType_ = cutlass::AlignedArray<Element_, ThreadMap_::kElementsPerAccess> >
|
||||
class DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized {
|
||||
public:
|
||||
//
|
||||
// Types
|
||||
//
|
||||
|
||||
using Shape = Shape_;
|
||||
using Element = Element_;
|
||||
using Layout = Layout_;
|
||||
using ThreadMap = ThreadMap_;
|
||||
using AccessType = AccessType_;
|
||||
using TensorRef = cutlass::TensorRef<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized;
|
||||
static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
|
||||
static int const kConvDim = 2;
|
||||
using ConvProblemSize = typename conv::Conv2dProblemSize;
|
||||
|
||||
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
||||
|
||||
static int const kFilterSize = ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess * ThreadMap::kThreads *
|
||||
sizeof_bits<Element>::value / 8;
|
||||
|
||||
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
||||
"Vectors implied by the thread map must be divisible by the access type.");
|
||||
|
||||
//
|
||||
// Simplifying assertions
|
||||
//
|
||||
static_assert(ThreadMap::Iterations::kContiguous == 1,
|
||||
"Require Iterations::kContiguous == 1");
|
||||
|
||||
//
|
||||
// Parameters structure
|
||||
//
|
||||
using Params = Depthwise2dFpropDirectConvFilterIteratorParams<Layout>;
|
||||
|
||||
protected:
|
||||
|
||||
Conv2dProblemSize const &problem_size_;
|
||||
Params const ¶ms_;
|
||||
LongIndex iteration_contiguous_;
|
||||
LongIndex iteration_strided_;
|
||||
LongIndex iteration_vector_;
|
||||
char const *pointer_;
|
||||
|
||||
int filter_k_;
|
||||
int offset_trs_[ThreadMap::Iterations::kStrided];
|
||||
|
||||
public:
|
||||
|
||||
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized(
|
||||
Params const ¶ms,
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Element const *ptr,
|
||||
int thread_idx,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord()
|
||||
):
|
||||
params_(params),
|
||||
problem_size_(problem_size),
|
||||
pointer_(reinterpret_cast<char const *>(ptr)),
|
||||
filter_k_(0) {
|
||||
|
||||
layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);
|
||||
|
||||
filter_k_ = threadblock_offset.column() + thread_coord.contiguous();
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
||||
offset_trs_[s] = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;
|
||||
}
|
||||
|
||||
set_iteration_index(0);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) {
|
||||
return Params(problem_size, layout, {Shape::kRow, Shape::kColumn}, kFilterSize);
|
||||
}
|
||||
|
||||
/// Overrides the internal iteration index
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_iteration_index(Index index) {
|
||||
iteration_vector_ = index % kAccessesPerVector;
|
||||
int residual_access = index / kAccessesPerVector;
|
||||
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
||||
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_pointer_offset(LongIndex pointer_offset) {
|
||||
pointer_ += pointer_offset * 8 / sizeof_bits<Element>::value;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void advance() {
|
||||
// Do nothing because the filter is persistent in the SMEM
|
||||
}
|
||||
|
||||
/// Returns the coordinate in the filter tensor W that is currently pointed to
|
||||
/// by the iterator.
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorCoord at() const {
|
||||
|
||||
int k = filter_k_ + iteration_vector_ * AccessType::kElements;
|
||||
int trs = offset_trs_[iteration_strided_];
|
||||
|
||||
return TensorCoord(k, trs, 0 , 0); // As a 2D-matrix
|
||||
}
|
||||
|
||||
/// Returns true if the current coordinate is within the activations tensor W
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() const {
|
||||
|
||||
TensorCoord coord = at();
|
||||
|
||||
return coord.n() < problem_size_.K &&
|
||||
coord.h() < Shape::kColumn;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the vector starting at the current coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType const *get() const {
|
||||
TensorCoord coord = at();
|
||||
int64_t offset = coord.n();
|
||||
if (params_.is_convolution) {
|
||||
offset += (Shape::kColumn - coord.h() - 1)* problem_size_.K;
|
||||
} else {
|
||||
offset += coord.h() * problem_size_.K;
|
||||
}
|
||||
|
||||
return reinterpret_cast<AccessType const *>(pointer_ +
|
||||
offset * sizeof_bits<Element>::value / 8);
|
||||
}
|
||||
|
||||
/// Increments to the next memory access
|
||||
CUTLASS_HOST_DEVICE
|
||||
DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized &operator++() {
|
||||
++iteration_vector_;
|
||||
if (iteration_vector_ < kAccessesPerVector) {
|
||||
return *this;
|
||||
}
|
||||
iteration_vector_ = 0;
|
||||
|
||||
++iteration_contiguous_;
|
||||
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
||||
return *this;
|
||||
}
|
||||
iteration_contiguous_ = 0;
|
||||
|
||||
++iteration_strided_;
|
||||
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
||||
return *this;
|
||||
}
|
||||
iteration_strided_ = 0;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Determines the filter size loaded by iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
int get_load_size() {
|
||||
return kFilterSize;
|
||||
}
|
||||
|
||||
/// Determines whether the Implicit GEMM can execute the given problem.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Status can_implement(Conv2dProblemSize const &problem_size) {
|
||||
|
||||
// check alignment constraint on iterator's contiguous dimension
|
||||
if (problem_size.K % AccessType::kElements) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
// check whether runtime filter size is same as templated filter size.
|
||||
if ((problem_size.R * problem_size.S) != Shape::kColumn) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
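The two `can_implement` checks above amount to a channel-alignment test and a compile-time filter-extent test. A hedged standalone sketch of the same logic in plain C++, where `kAccessElements` and `kTileColumns` are stand-ins for `AccessType::kElements` and `Shape::kColumn` and the numbers are hypothetical:

#include <cstdio>

// Hedged mirror of the validity checks in can_implement(): a depthwise
// problem is accepted only if the channel count is vector-aligned and the
// runtime filter extent R*S matches the extent the tile was compiled for.
bool can_implement_sketch(int K, int R, int S,
                          int kAccessElements,  // stands in for AccessType::kElements
                          int kTileColumns) {   // stands in for Shape::kColumn
  if (K % kAccessElements) {
    return false;  // misaligned channel dimension
  }
  if (R * S != kTileColumns) {
    return false;  // runtime filter size differs from the templated size
  }
  return true;
}

int main() {
  // A 3x3 depthwise filter compiled with Shape::kColumn == 9 and 4-element
  // accesses accepts K = 64 but rejects K = 30.
  std::printf("%d\n", can_implement_sketch(64, 3, 3, 4, 9));  // 1
  std::printf("%d\n", can_implement_sketch(30, 3, 3, 4, 9));  // 0
  return 0;
}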
229
include/cutlass/conv/threadblock/depthwise_mma_base.h
Normal file
@ -0,0 +1,229 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a direct-conv threadblock-scoped Depthwise kernel.
*/

#pragma once

#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Policy object describing MmaTensorOp
template <
    /// Warp-level GEMM operator (concept: gemm::warp::Mma)
    typename Operator_,
    /// Padding used for A operand in shared memory (concept: MatrixShape)
    typename SmemPaddingA_,
    /// Padding used for B operand in shared memory (concept: MatrixShape)
    typename SmemPaddingB_,
    ///
    typename ThreadMapA_,
    ///
    typename ThreadMapB_,
    /// Number of partitions of K dimension of GEMM
    int PartitionsK = 1>
struct DepthwiseDirectConvMmaPolicy {
  /// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt)
  using Operator = Operator_;

  /// Padding used for A operand in shared memory
  using SmemPaddingA = SmemPaddingA_;

  /// Padding used for B operand in shared memory
  using SmemPaddingB = SmemPaddingB_;

  using ThreadMapA = ThreadMapA_;
  using ThreadMapB = ThreadMapB_;

  /// Number of partitions of K dimension
  static int const kPartitionsK = PartitionsK;
};

////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Number of stages,
    int Stages,
    /// Used for partial specialization
    typename Enable = bool>
class DepthwiseDirectConvMmaBase {
public:
  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape = Shape_;

  ///< Policy describing tuning details
  using Policy = Policy_;

  //
  // Dependent types
  //

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  /// Shape describing the overall GEMM computed from shared memory
  /// by each warp.
  using WarpGemm = typename Policy::Operator::Shape;

  /// Shape describing the number of warps filling the CTA
  using WarpCount = cutlass::gemm::
      GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;

  /// Number of warp-level GEMM operations.
  /// kWarpGemmIterations may be even or odd.
  static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);

  /// Number of stages
  static int const kStages = Stages;

  /// Tensor reference to the A operand
  using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;

  /// Tensor reference to the B operand
  using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;

  static_assert(kWarpGemmIterations > 1,
                "The pipelined structure requires at least two warp-level "
                "GEMM operations.");

  //
  // Nested structs
  //

  /// Shared storage object needed by threadblock-scoped GEMM
  class SharedStorage {
  public:
    //
    // Type definitions
    //

    /// Shape of the A matrix operand in shared memory
    using ShapeA = MatrixShape<1, // Not determined at compile-time :(
                               Shape::kN + Policy::SmemPaddingA::kRow>;

    /// Shape of the B matrix operand in shared memory
    using ShapeB = MatrixShape<Policy::ThreadMapB::StorageShape::kStrided +
                                   Policy::SmemPaddingB::kRow, // filter_rs_size
                               Policy::ThreadMapB::StorageShape::kContiguous +
                                   Policy::SmemPaddingB::kColumn>; // Tile N = 64?

  public:
    //
    // Data members
    //

    // Place the persistent B matrix in front of the dynamically-sized A matrix.
    /// Buffer for B operand
    AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;

    /// Buffer for A operand
    /// Cannot be determined at compile-time -- just used to get an SMEM start address.
    AlignedBuffer<typename Operator::ElementA, 1> operand_A;

  public:
    //
    // Methods
    //

    /// Returns a layout object for the A matrix
    CUTLASS_DEVICE
    static typename Operator::LayoutA LayoutA() {
      return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
    }

    /// Returns a layout object for the B matrix
    CUTLASS_HOST_DEVICE
    static typename Operator::LayoutB LayoutB() {
      return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
    }

    /// Returns a TensorRef to the A operand
    CUTLASS_HOST_DEVICE
    TensorRefA operand_A_ref() { return TensorRefA{operand_A.data(), LayoutA()}; }

    /// Returns a TensorRef to the B operand
    CUTLASS_HOST_DEVICE
    TensorRefB operand_B_ref() { return TensorRefB{operand_B.data(), LayoutB()}; }
  };

protected:
  //
  // Data members
  //

  /// Iterator to load a warp-scoped tile of A operand from shared memory
  typename Operator::IteratorA warp_tile_iterator_A_;

  /// Iterator to load a warp-scoped tile of B operand from shared memory
  typename Operator::IteratorB warp_tile_iterator_B_;

public:
  /// Construct from tensor references
  CUTLASS_DEVICE
  DepthwiseDirectConvMmaBase(
      ///< Shared storage needed for internal use by threadblock-scoped GEMM
      SharedStorage &shared_storage,
      ///< ID within the threadblock
      int thread_idx,
      ///< ID of warp
      int warp_idx,
      ///< ID of each thread within a warp
      int lane_idx)
      : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
        warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
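The warp count and warp-level iteration count in `DepthwiseDirectConvMmaBase` are pure compile-time divisions, so they can be sanity-checked in isolation. A hedged sketch with hypothetical tile sizes (not shapes asserted by this commit):

#include <cstdio>

// Hypothetical tile sizes used only for illustration.
constexpr int kShapeM = 64, kShapeN = 64, kShapeK = 8;  // threadblock tile
constexpr int kWarpM = 16, kWarpN = 64, kWarpK = 8;     // warp tile (WarpGemm)
constexpr int kMmaK = 4;                                // Operator::Policy::MmaShape::kK

// Mirrors WarpCount = GemmShape<Shape::kM / WarpGemm::kM, ...> and
// kWarpGemmIterations = WarpGemm::kK / MmaShape::kK from the class above.
constexpr int kWarpCountM = kShapeM / kWarpM;        // 4 warps along M
constexpr int kWarpCountN = kShapeN / kWarpN;        // 1 warp along N
constexpr int kWarpGemmIterations = kWarpK / kMmaK;  // 2 iterations

static_assert(kWarpGemmIterations > 1,
              "The pipelined structure requires at least two warp-level GEMM operations.");

int main() {
  std::printf("warps: %dx%d, warp-level iterations: %d\n",
              kWarpCountM, kWarpCountN, kWarpGemmIterations);
  return 0;
}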
@ -44,11 +44,17 @@
#include "cutlass/matrix_shape.h"

#include "cutlass/gemm/warp/mma.h"

#include "cutlass/conv/convolution.h"
#include "cutlass/conv/warp/mma_depthwise_simt.h"

#include "cutlass/gemm/threadblock/mma_pipelined.h"
#include "cutlass/gemm/threadblock/mma_singlestage.h"

#include "cutlass/gemm/threadblock/mma_base.h"
#include "cutlass/conv/warp/mma_depthwise_simt.h"
#include "cutlass/conv/threadblock/depthwise_mma_base.h"

#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h"

#include "cutlass/arch/cache_operation.h"

@ -58,6 +64,95 @@ namespace cutlass {
namespace conv {
namespace threadblock {

namespace detail {
// Convert a WarpShapeM, which is the whole tile of elements, into the number of
// elements (2D) held by each partition within the warp.
// The goal is for each thread's tile of elements to be as square as
// possible for performance (4x4 will be faster than 2x8).
template<int WarpShapeM,      // The number of elements (1D) contained in the entire warp
         int WarpNumThreadsM> // The number of partitions within the warp
struct SimtWarpShape {
  // kP * kQ * WarpNumThreadsM = WarpShapeM
  // If needed, enable more specializations.
};

template <>
struct SimtWarpShape<4, 4> {
  static constexpr int kP = 1;
  static constexpr int kQ = 1;
};

template <>
struct SimtWarpShape<4, 2> {
  static constexpr int kP = 2;
  static constexpr int kQ = 1;
};

template <>
struct SimtWarpShape<4, 1> {
  static constexpr int kP = 2;
  static constexpr int kQ = 2;
};

template <>
struct SimtWarpShape<8, 1> {
  static constexpr int kP = 2;
  static constexpr int kQ = 4;
};

template <>
struct SimtWarpShape<8, 2> {
  static constexpr int kP = 2;
  static constexpr int kQ = 2;
};

template <>
struct SimtWarpShape<8, 4> {
  static constexpr int kP = 1;
  static constexpr int kQ = 2;
};

template <>
struct SimtWarpShape<16, 1> {
  static constexpr int kP = 4;
  static constexpr int kQ = 4;
};

template <>
struct SimtWarpShape<16, 2> {
  static constexpr int kP = 2;
  static constexpr int kQ = 4;
};

template <>
struct SimtWarpShape<16, 4> {
  static constexpr int kP = 2;
  static constexpr int kQ = 2;
};

template <int WarpNumThreadsM>
struct SimtWarpShape<25, WarpNumThreadsM> {
  static_assert(WarpNumThreadsM == 1, "WarpShapeM cannot be evenly split across threads");
  static constexpr int kP = 5;
  static constexpr int kQ = 5;
};

template <>
struct SimtWarpShape<32, 1> {
  static constexpr int kP = 4;
  static constexpr int kQ = 8;
};

template <>
struct SimtWarpShape<32, 2> {
  static constexpr int kP = 4;
  static constexpr int kQ = 4;
};

template <>
struct SimtWarpShape<32, 4> {
  static constexpr int kP = 2;
  static constexpr int kQ = 4;
};

} // namespace detail
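To make the mapping concrete, here is a hedged compile-time check of the invariant stated in the comment above, kP * kQ * WarpNumThreadsM == WarpShapeM, against a few of the specializations. The trait below is a minimal stand-in mirroring the diff's kP/kQ values, not the CUTLASS definition itself:

// Minimal mirror of detail::SimtWarpShape for a compile-time sanity check.
template <int WarpShapeM, int WarpNumThreadsM> struct SimtWarpShapeSketch;

template <> struct SimtWarpShapeSketch<16, 2> { static constexpr int kP = 2, kQ = 4; };
template <> struct SimtWarpShapeSketch<32, 4> { static constexpr int kP = 2, kQ = 4; };
template <> struct SimtWarpShapeSketch<25, 1> { static constexpr int kP = 5, kQ = 5; };

// Invariant from the comment: kP * kQ * WarpNumThreadsM == WarpShapeM.
static_assert(SimtWarpShapeSketch<16, 2>::kP * SimtWarpShapeSketch<16, 2>::kQ * 2 == 16, "");
static_assert(SimtWarpShapeSketch<32, 4>::kP * SimtWarpShapeSketch<32, 4>::kQ * 4 == 32, "");
static_assert(SimtWarpShapeSketch<25, 1>::kP * SimtWarpShapeSketch<25, 1>::kQ * 1 == 25, "");

int main() { return 0; }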
template <
    /// Shape of threadblock-scoped matrix multiply operator
    typename Shape,
@ -114,6 +209,74 @@ struct DepthwiseMmaCoreWithLaneAccessSize;

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// Shape of threadblock-scoped matrix multiply operator
    typename Shape,
    /// Shape of threadblock-scoped output tile
    typename ThreadBlockOutputShape,
    /// Shape of filter shape per threadblock
    typename FilterShape,
    /// Shape of warp-level matrix multiply operator
    typename WarpShape,
    /// Shape of one matrix production operation (concept: GemmShape)
    typename InstructionShape,
    /// Element data type of A operand
    typename ElementA,
    /// Layout of operand A
    typename LayoutA,
    /// Element data type of B operand
    typename ElementB,
    /// Layout of operand B
    typename LayoutB,
    /// Data type of accumulator
    typename ElementC,
    /// Layout of accumulator
    typename LayoutC,
    /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp)
    typename OperatorClass,
    /// Size of a warp-scoped per thread access
    int kLaneAccessSizeA_ = 0,
    /// Size of a warp-scoped per thread access
    int kLaneAccessSizeB_ = 0,
    /// Number of stages
    int Stages = 2,
    /// Operation performed by MMA
    typename Operator = typename platform::conditional<
        (platform::is_same<OperatorClass,
                           cutlass::arch::OpClassTensorOp>::value) &&
            (platform::is_same<ElementA, int8_t>::value ||
             platform::is_same<ElementA, int4b_t>::value ||
             platform::is_same<ElementA, uint8_t>::value ||
             platform::is_same<ElementA, uint4b_t>::value),
        cutlass::arch::OpMultiplyAddSaturate,
        cutlass::arch::OpMultiplyAdd>::type,
    /// Iterator algo type
    conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
    /// Stride ( MatrixShape<Height, Width> )
    typename StrideShape = cutlass::MatrixShape<-1, -1>,
    /// Dilation ( MatrixShape<Height, Width> )
    typename DilationShape = cutlass::MatrixShape<-1, -1>,
    /// Activation Shape loaded by threadblock
    typename ActivationShape = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>,
    /// Store the accumulators in row major or column major. Row major is used
    /// when output layout is interleaved.
    bool AccumulatorsInRowMajor = false,
    /// Cache operation of operand A
    cutlass::arch::CacheOperation::Kind CacheOpA =
        cutlass::arch::CacheOperation::Global,
    /// Cache operation of operand B
    cutlass::arch::CacheOperation::Kind CacheOpB =
        cutlass::arch::CacheOperation::Global,
    /// per-element transformation for elements of A
    ComplexTransform TransformA = ComplexTransform::kNone,
    /// per-element transformation for elements of B
    ComplexTransform TransformB = ComplexTransform::kNone,
    bool IsComplex = false // (is_complex<ElementA>::value || is_complex<ElementB>::value)
    >
struct DepthwiseDirectConvMmaCoreWithLaneAccessSize;

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// Shape of threadblock-scoped matrix multiply operator
    typename Shape,
@ -332,6 +495,458 @@ struct DepthwiseMmaCoreWithLaneAccessSize<Shape_,
      >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization:
///
/// A: row-major
/// B: row-major
/// Operator: simt class
///
/// This uses the default warp-level operator given tile sizes
template <
    /// Shape of threadblock-scoped matrix multiply operator (concept:
    /// GemmShape)
    typename Shape_,
    /// Shape of threadblock-scoped output tile (concept: TensorNHWCShape)
    typename ThreadBlockOutputShape_,
    /// Shape of filter shape per threadblock
    typename FilterShape_,
    /// Shape of warp-level matrix multiply operator (concept: GemmShape)
    typename WarpShape_,
    /// Data type of A operand
    typename ElementA_,
    /// Data type of B operand
    typename ElementB_,
    /// Data type of accumulator
    typename ElementC_,
    /// Layout of accumulator
    typename LayoutC_,
    /// Size of a warp-scoped per thread access
    int kLaneAccessSizeA_,
    /// Number of stages
    int Stages_,
    /// Operation performed by GEMM
    typename Operator_>
struct DepthwiseDirectConvMmaCoreWithLaneAccessSize<Shape_,
                                                    ThreadBlockOutputShape_,
                                                    FilterShape_,
                                                    WarpShape_,
                                                    cutlass::gemm::GemmShape<1, 1, 1>,
                                                    ElementA_,
                                                    layout::RowMajor,
                                                    ElementB_,
                                                    layout::ColumnMajor,
                                                    ElementC_,
                                                    LayoutC_,
                                                    arch::OpClassSimt,
                                                    kLaneAccessSizeA_,
                                                    128,
                                                    Stages_,
                                                    Operator_> {
  using Shape = Shape_;
  using FilterShape = FilterShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
  using ElementA = ElementA_;
  using LayoutA = layout::RowMajor;
  using ElementB = ElementB_;
  using LayoutB = layout::ColumnMajor;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using OperatorClass = arch::OpClassSimt;

  static int const kLaneAccessSizeB = 128;

  // Divisibility requirements
  static_assert(kLaneAccessSizeB > 0,
                "Size of a warp-scoped per thread access should be larger than zero");

  /// Default Operator
  using Operator = Operator_;

  /// Number of warps present
  using WarpCount = cutlass::gemm::GemmShape<
      Shape::kM / WarpShape::kM,
      Shape::kN / WarpShape::kN,
      1
  >;

  // Divisibility requirements
  static_assert(
      !(Shape::kM % WarpShape::kM) &&
      !(Shape::kN % WarpShape::kN),
      "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
  );

  /// Number of threads per warp
  static int const kWarpSize = cutlass::gemm::warp::WarpSize<arch::OpClassSimt>::value;

  /// Number of threads total
  static int const kThreads = WarpCount::kCount * kWarpSize;

  // For Gmem load
  static int const kElementsPerAccessA = 128 / sizeof_bits<ElementA>::value;
  static int const kElementsPerAccessB = 128 / sizeof_bits<ElementB>::value;

  //
  // Shared memory layouts
  //

  using SmemLayoutA = layout::RowMajor;
  using SmemLayoutB = layout::RowMajor;

  //
  // Iterators to write to shared memory
  //

  /// ThreadMap of iterator A
  using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
      layout::PitchLinearShape<Shape::kN, 1>, // Set kStrided = 1 because the activation shape is a runtime value.
      kThreads,
      kElementsPerAccessA
  >;

  /// ThreadMap of iterator A
  using SmemThreadMapA = IteratorThreadMapA;

  /// Shared memory iterator to A operand
  using SmemIteratorA = transform::threadblock::RegularTileAccessIteratorDirectConv<
      MatrixShape<1, Shape::kN>, // set kRow to 1 because it is a runtime value
      ElementA,
      SmemLayoutA,
      0,
      SmemThreadMapA, // was IteratorThreadMapA
      true // Dynamic iterations.
  >;

  /// ThreadMap of iterator B
  using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap<
      layout::PitchLinearShape<Shape::kN, FilterShape::kCount>,
      kThreads,
      kElementsPerAccessB
  >;

  /// Transpose the ThreadMap of iterator B
  using SmemThreadMapB = IteratorThreadMapB;

  /// Shared memory iterator to B operand
  using SmemIteratorB = transform::threadblock::RegularTileAccessIteratorDirectConv<
      MatrixShape<FilterShape::kCount, Shape::kN>,
      ElementB,
      SmemLayoutB,
      0,
      SmemThreadMapB, // was IteratorThreadMapB
      false // Static iterations.
  >;

  //
  // Warp-level matrix multiply operator
  //
  // Groups per thread:
  //   Fp32: 2 groups
  //   Fp16: 2 groups
  //   Int8: 4 groups
  static const int GroupsPerThread = sizeof(ElementB) > 1 ? 2 : 4;
  // Define the warp-level op
  static const int WarpNumThreadsN = cutlass::const_min(WarpShape::kN / GroupsPerThread, kWarpSize);
  static const int WarpNumThreadsM = kWarpSize / WarpNumThreadsN;

  static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN),
                "WarpShape must be divisible by ThreadTile shape.");

  // Get output P, Q per thread
  static const int TileP = cutlass::conv::threadblock::detail::SimtWarpShape<WarpShape::kM, WarpNumThreadsM>::kP;
  static const int TileQ = cutlass::conv::threadblock::detail::SimtWarpShape<WarpShape::kM, WarpNumThreadsM>::kQ;

  static const int LaneLayout = 1;
  static const int numElementsB = kLaneAccessSizeB / sizeof_bits<ElementB>::value;
  static const int LaneN = cutlass::const_min(numElementsB, WarpShape::kN / WarpNumThreadsN);

  // Define the output tile computed by each thread
  using ThreadOutputShape = cutlass::conv::TensorNHWCShape<1, TileP, TileQ, LaneN>;

  // Fetch the channel with the same access size
  static const int LaneM = LaneN;

  // No paddings
  static int const kPaddingM = 0;
  static int const kPaddingN = 0;

  static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN),
                "Padding must be divisible by Lane");

  // These should also take the max of the thread tile.
  using LaneMmaShape = cutlass::gemm::GemmShape<
      LaneM,
      LaneN,
      1>;

  using Policy = cutlass::gemm::warp::MmaSimtPolicy<
      cutlass::MatrixShape<WarpNumThreadsM, WarpNumThreadsN>, // WarpShape
      cutlass::layout::RowMajorInterleaved<LaneLayout>,       // LaneLayout
      LaneMmaShape
  >;

  using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseDirectConvSimt<
      WarpShape,               /// Size of the Gemm problem - concept: gemm::GemmShape<>
      FilterShape,             /// Shape of filter shape per threadblock - concept: gemm::GemmShape<Depth, Height, Width>
      ThreadOutputShape,       /// Size of the output tile computed by thread - concept: conv::TensorNHWCShape<>
      ThreadBlockOutputShape_, /// Size of the output tile computed by threadblock - concept: conv::TensorNHWCShape<>
      ElementA,                /// Data type of A elements
      SmemLayoutA,             /// Layout of A matrix (concept: MatrixLayout)
      ElementB,                /// Data type of B elements
      SmemLayoutB,             /// Layout of B matrix (concept: MatrixLayout)
      ElementC,                /// Element type of C matrix
      LayoutC,                 /// Layout of C matrix (concept: MatrixLayout)
      Policy                   /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy)
      >;

  /// Policy used to define MmaPipelined
  using MmaPolicy = cutlass::conv::threadblock::DepthwiseDirectConvMmaPolicy<
      MmaWarpSimt,
      MatrixShape<kPaddingM, 0>, // skew for A matrix to avoid SMEM bank conflicts
      MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts
      IteratorThreadMapA,
      IteratorThreadMapB,
      WarpCount::kK
      >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization:
///
/// A: row-major
/// B: row-major
/// Operator: simt class
///
/// This uses the default warp-level operator given tile sizes
template <
    /// Shape of threadblock-scoped matrix multiply operator (concept:
    /// GemmShape)
    typename Shape_,
    /// Shape of threadblock-scoped output tile (concept: TensorNHWCShape)
    typename ThreadBlockOutputShape_,
    /// Shape of filter shape per threadblock
    typename FilterShape_,
    /// Shape of warp-level matrix multiply operator (concept: GemmShape)
    typename WarpShape_,
    /// Data type of A operand
    typename ElementA_,
    /// Data type of B operand
    typename ElementB_,
    /// Data type of accumulator
    typename ElementC_,
    /// Layout of accumulator
    typename LayoutC_,
    /// Size of a warp-scoped per thread access
    int kLaneAccessSizeA_,
    /// Number of stages
    int Stages_,
    /// Operation performed by GEMM
    typename Operator_,
    /// Stride ( MatrixShape<Height, Width> )
    typename StrideShape_,
    /// Dilation ( MatrixShape<Height, Width> )
    typename DilationShape_,
    /// Activation Shape loaded by threadblock
    typename ActivationShape_>
struct DepthwiseDirectConvMmaCoreWithLaneAccessSize<Shape_,
                                                    ThreadBlockOutputShape_,
                                                    FilterShape_,
                                                    WarpShape_,
                                                    cutlass::gemm::GemmShape<1, 1, 1>,
                                                    ElementA_,
                                                    layout::RowMajor,
                                                    ElementB_,
                                                    layout::ColumnMajor,
                                                    ElementC_,
                                                    LayoutC_,
                                                    arch::OpClassSimt,
                                                    kLaneAccessSizeA_,
                                                    128,
                                                    Stages_,
                                                    Operator_,
                                                    IteratorAlgorithm::kFixedStrideDilation,
                                                    StrideShape_,
                                                    DilationShape_,
                                                    ActivationShape_> {
  using Shape = Shape_;
  using FilterShape = FilterShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
  using ElementA = ElementA_;
  using LayoutA = layout::RowMajor;
  using ElementB = ElementB_;
  using LayoutB = layout::ColumnMajor;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using OperatorClass = arch::OpClassSimt;
  using StrideShape = StrideShape_;
  using DilationShape = DilationShape_;
  using ThreadBlockOutputShape = ThreadBlockOutputShape_;
  using ActivationShape = ActivationShape_;

  static int const kLaneAccessSizeB = 128;

  // Divisibility requirements
  static_assert(kLaneAccessSizeB > 0,
                "Size of a warp-scoped per thread access should be larger than zero");

  /// Default Operator
  using Operator = Operator_;

  /// Number of warps present
  using WarpCount = cutlass::gemm::GemmShape<
      Shape::kM / WarpShape::kM,
      Shape::kN / WarpShape::kN,
      1
  >;

  // Divisibility requirements
  static_assert(
      !(Shape::kM % WarpShape::kM) &&
      !(Shape::kN % WarpShape::kN),
      "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
  );

  /// Number of threads per warp
  static int const kWarpSize = cutlass::gemm::warp::WarpSize<arch::OpClassSimt>::value;

  /// Number of threads total
  static int const kThreads = WarpCount::kCount * kWarpSize;

  // For Gmem load
  static int const kElementsPerAccessA = 128 / sizeof_bits<ElementA>::value;
  static int const kElementsPerAccessB = 128 / sizeof_bits<ElementB>::value;

  //
  // Shared memory layouts
  //

  using SmemLayoutA = layout::RowMajor;
  using SmemLayoutB = layout::RowMajor;

  //
  // Iterators to write to shared memory
  //

  /// ThreadMap of iterator A
  using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
      layout::PitchLinearShape<ActivationShape::kC, ActivationShape::kNHW>,
      kThreads,
      kElementsPerAccessA
  >;

  /// ThreadMap of iterator A
  using SmemThreadMapA = IteratorThreadMapA;

  /// Shared memory iterator to A operand
  using SmemIteratorA = transform::threadblock::RegularTileAccessIteratorDirectConv<
      MatrixShape<ActivationShape::kNHW, ActivationShape::kC>,
      ElementA,
      SmemLayoutA,
      0,
      SmemThreadMapA, // was IteratorThreadMapA
      false // Static iterations.
  >;

  /// ThreadMap of iterator B
  using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap<
      layout::PitchLinearShape<Shape::kN, FilterShape::kCount>,
      kThreads,
      kElementsPerAccessB
  >;

  /// Transpose the ThreadMap of iterator B
  using SmemThreadMapB = IteratorThreadMapB;

  /// Shared memory iterator to B operand
  using SmemIteratorB = transform::threadblock::RegularTileAccessIteratorDirectConv<
      MatrixShape<FilterShape::kCount, Shape::kN>,
      ElementB,
      SmemLayoutB,
      0,
      SmemThreadMapB, // was IteratorThreadMapB
      false // Static iterations.
  >;

  //
  // Warp-level matrix multiply operator
  //
  // Groups per thread:
  //   Fp32: 2 groups
  //   Fp16: 2 groups
  //   Int8: 4 groups
  static const int GroupsPerThread = sizeof(ElementB) > 1 ? 2 : 4;
  // Define the warp-level op
  static const int WarpNumThreadsN = cutlass::const_min(WarpShape::kN / GroupsPerThread, kWarpSize);
  static const int WarpNumThreadsM = kWarpSize / WarpNumThreadsN;

  static const int TileP = cutlass::conv::threadblock::detail::SimtWarpShape<WarpShape::kM, WarpNumThreadsM>::kP;
  static const int TileQ = cutlass::conv::threadblock::detail::SimtWarpShape<WarpShape::kM, WarpNumThreadsM>::kQ;

  static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN),
                "WarpShape must be divisible by ThreadTile shape.");

  static const int LaneLayout = 1;
  static const int numElementsB = kLaneAccessSizeB / sizeof_bits<ElementB>::value;
  static const int LaneN = cutlass::const_min(numElementsB, WarpShape::kN / WarpNumThreadsN);

  // Define the output tile computed by each thread
  using ThreadOutputShape = cutlass::conv::TensorNHWCShape<1, TileP, TileQ, LaneN>;

  // Fetch the channel with the same access size
  static const int LaneM = LaneN;

  // No paddings
  static int const kPaddingM = 0;
  static int const kPaddingN = 0;

  static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN),
                "Padding must be divisible by Lane");

  // These should also take the max of the thread tile.
  using LaneMmaShape = cutlass::gemm::GemmShape<
      LaneM,
      LaneN,
      1>;

  using Policy = cutlass::gemm::warp::MmaSimtPolicy<
      cutlass::MatrixShape<WarpNumThreadsM, WarpNumThreadsN>, // WarpShape
      cutlass::layout::RowMajorInterleaved<LaneLayout>,       // LaneLayout
      LaneMmaShape
  >;

  using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseDirectConvSimt<
      WarpShape,              /// Size of the Gemm problem - concept: gemm::GemmShape<>
      FilterShape,            /// Shape of filter shape per threadblock - concept: gemm::GemmShape<Depth, Height, Width>
      ThreadOutputShape,      /// Size of the output tile computed by thread - concept: conv::TensorNHWCShape<>
      ThreadBlockOutputShape, /// Size of the output tile computed by threadblock - concept: conv::TensorNHWCShape<>
      ElementA,               /// Data type of A elements
      SmemLayoutA,            /// Layout of A matrix (concept: MatrixLayout)
      ElementB,               /// Data type of B elements
      SmemLayoutB,            /// Layout of B matrix (concept: MatrixLayout)
      ElementC,               /// Element type of C matrix
      LayoutC,                /// Layout of C matrix (concept: MatrixLayout)
      Policy,                 /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy)
      IteratorAlgorithm::kFixedStrideDilation, /// Iterator algo type
      StrideShape,            /// Stride ( MatrixShape<Height, Width> )
      DilationShape,          /// Dilation ( MatrixShape<Height, Width> )
      ActivationShape         /// Activation Shape loaded by threadblock
      >;

  /// Policy used to define MmaPipelined
  using MmaPolicy = cutlass::conv::threadblock::DepthwiseDirectConvMmaPolicy<
      MmaWarpSimt,
      MatrixShape<kPaddingM, 0>, // skew for A matrix to avoid SMEM bank conflicts
      MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts
      IteratorThreadMapA,
      IteratorThreadMapB,
      WarpCount::kK
      >;
};
} // namespace threadblock
} // namespace conv
} // namespace cutlass
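The warp-level decomposition in the specializations above is all integer arithmetic, so it can be walked through by hand. A hedged worked example for fp32 with a hypothetical 32x64 (M x N) warp tile, which is not a shape asserted by this commit:

#include <algorithm>
#include <cstdio>

int main() {
  // Hedged walk-through of the warp-level decomposition for fp32 and a
  // hypothetical WarpShape of 32x64; all numbers are illustrative only.
  constexpr int kWarpSize = 32;
  constexpr int kWarpShapeM = 32, kWarpShapeN = 64;
  constexpr int kElementBytes = 4;  // fp32
  constexpr int kGroupsPerThread = kElementBytes > 1 ? 2 : 4;
  constexpr int kWarpNumThreadsN = std::min(kWarpShapeN / kGroupsPerThread, kWarpSize);  // 32
  constexpr int kWarpNumThreadsM = kWarpSize / kWarpNumThreadsN;                         // 1
  // SimtWarpShape<32, 1> in the diff gives kP = 4, kQ = 8.
  constexpr int kTileP = 4, kTileQ = 8;
  constexpr int kNumElementsB = 128 / (kElementBytes * 8);                               // 4
  constexpr int kLaneN = std::min(kNumElementsB, kWarpShapeN / kWarpNumThreadsN);        // 2
  static_assert(kTileP * kTileQ * kWarpNumThreadsM == kWarpShapeM, "P*Q*threadsM == WarpShapeM");
  std::printf("ThreadOutputShape = <1, %d, %d, %d>\n", kTileP, kTileQ, kLaneN);
  return 0;
}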
@ -165,7 +165,29 @@ struct StridedDgradIdentityThreadblockSwizzle :

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Threadblock swizzling function for GEMMs
template <int N = 1, int Output_N = 1, int Output_P = 1, int Output_Q = 1>
struct DepthwiseDirect2dConvIdentityThreadblockSwizzle
    : public gemm::threadblock::GemmIdentityThreadblockSwizzle<N> {
  CUTLASS_HOST_DEVICE
  DepthwiseDirect2dConvIdentityThreadblockSwizzle() {}

  /// Returns the shape of the problem in units of logical tiles
  CUTLASS_HOST_DEVICE
  gemm::GemmCoord get_tiled_shape(cutlass::conv::Operator conv_operator,
                                  cutlass::conv::Conv2dProblemSize const &problem_size,
                                  gemm::GemmCoord tile_size,
                                  int split_k_slices) const {

    gemm::GemmCoord implicit_gemm_problem_size =
        cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size);

    return gemm::GemmCoord(1,
                           (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(),
                           split_k_slices);
  }
};

} // namespace threadblock
} // namespace gemm
} // namespace conv
} // namespace cutlass
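The grid computed by `get_tiled_shape` above is a ceiling division over the implicit-GEMM N extent plus a split-k dimension, with the M coordinate fixed at 1. A hedged numeric sketch of that arithmetic (the problem sizes are hypothetical, not values from this commit):

#include <cstdio>

int main() {
  // Hypothetical depthwise problem: implicit-GEMM N extent and tile width.
  const int gemm_n = 96;          // implicit_gemm_problem_size.n()
  const int tile_n = 64;          // tile_size.n()
  const int split_k_slices = 2;

  // Mirrors get_tiled_shape(): GemmCoord(1, ceil(gemm_n / tile_n), split_k_slices).
  int tiled_n = (gemm_n + tile_n - 1) / tile_n;  // 2
  std::printf("grid = (1, %d, %d)\n", tiled_n, split_k_slices);
  return 0;
}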
@ -42,6 +42,9 @@
#include "cutlass/gemm/warp/mma.h"

#include "cutlass/gemm/thread/mma.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/thread/depthwise_mma.h"

#include "cutlass/gemm/warp/mma_simt_tile_iterator.h"
#include "cutlass/gemm/warp/mma_simt_policy.h"
@ -91,7 +94,7 @@ class MmaDepthwiseSimt

public:
  /// Shape of warp-level matrix operation (concept: GemmShape)
  using Shape = Shape_; // < 64, 16 , 8>
  using Shape = Shape_;

  /// Data type of multiplicand A
  using ElementA = ElementA_;
@ -156,8 +159,223 @@ public:
  MmaDepthwiseSimt():Base() {}
};

/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Shape of filter shape per threadblock - concept: gemm::GemmShape<Depth, Height, Width>
    typename FilterShape_,
    /// Shape of the output tile computed by thread - concept: conv::TensorNHWCShape<>
    typename ThreadOutputShape_,
    /// Shape of the output tile computed by threadblock - concept: conv::TensorNHWCShape<>
    typename ThreadBlockOutputShape_,
    /// Data type of A elements
    typename ElementA_,
    /// Layout of A matrix (concept: MatrixLayout)
    typename LayoutA_,
    /// Data type of B elements
    typename ElementB_,
    /// Layout of B matrix (concept: MatrixLayout)
    typename LayoutB_,
    /// Element type of C matrix
    typename ElementC_,
    /// Layout of C matrix (concept: MatrixLayout)
    typename LayoutC_,
    /// Shape of the warp in units of thread (concept: MmaSimtPolicy)
    typename Policy_,
    /// Iterator algo type
    conv::IteratorAlgorithm IteratorAlgorithm_ = IteratorAlgorithm::kAnalytic,
    /// Stride ( MatrixShape<Height, Width> )
    typename StrideShape_ = cutlass::MatrixShape<-1, -1>,
    /// Dilation ( MatrixShape<Height, Width> )
    typename DilationShape_ = cutlass::MatrixShape<-1, -1>,
    /// Activation Shape loaded by threadblock
    typename ActivationShape_ = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>,
    /// Number of partitions along K dimension
    int PartitionsK = 1,
    /// Complex transformation on operand A
    ComplexTransform TransformA = ComplexTransform::kNone,
    /// Complex transformation on operand B
    ComplexTransform TransformB = ComplexTransform::kNone,
    /// Used for partial specialization
    typename Enable = bool>
class MmaDepthwiseDirectConvSimt {
public:
  /// Shape of warp-level matrix operation (concept: GemmShape)
  using Shape = Shape_;

  /// Shape of filter shape per threadblock - concept: gemm::GemmShape<Depth, Height, Width>
  using FilterShape = FilterShape_;

  /// Shape of the output tile computed by thread - concept: conv::TensorNHWCShape<>
  using ThreadOutputShape = ThreadOutputShape_;

  /// Shape of the output tile computed by threadblock - concept: conv::TensorNHWCShape<>
  using ThreadBlockOutputShape = ThreadBlockOutputShape_;

  /// Data type of multiplicand A
  using ElementA = ElementA_;

  /// Layout of multiplicand A
  using LayoutA = LayoutA_;

  /// Data type of multiplicand B
  using ElementB = ElementB_;

  /// Layout of multiplicand B
  using LayoutB = LayoutB_;

  /// Data type of accumulator matrix C
  using ElementC = ElementC_;

  /// Layout of accumulator matrix C
  using LayoutC = LayoutC_;

  /// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
  using Policy = Policy_;

  /// Iterator algo type
  static conv::IteratorAlgorithm const IteratorAlgorithm = IteratorAlgorithm_;

  /// Stride ( MatrixShape<Height, Width> )
  using StrideShape = StrideShape_;

  /// Dilation ( MatrixShape<Height, Width> )
  using DilationShape = DilationShape_;

  /// Activation Shape loaded by threadblock
  using ActivationShape = ActivationShape_;

  /// Indicates class of matrix operator
  using OperatorClass = arch::OpClassSimt;

  /// Hard-coded for now
  using ArchTag = arch::Sm50;

  /// Complex transform on A operand
  static ComplexTransform const kTransformA = TransformA;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB = TransformB;

  static constexpr bool use_dp4a = (platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA>::value ||
                                    platform::is_same< layout::RowMajorInterleaved<4>, LayoutA >::value) &&
                                   platform::is_same< ElementA, int8_t >::value &&
                                   platform::is_same< ElementB, int8_t >::value;

  using dp4a_type = typename platform::conditional< use_dp4a , int8_t, bool >::type;

  /// Thread-level matrix multiply accumulate operator
  using ThreadMma = cutlass::conv::thread::DepthwiseDirectConvElementwiseInnerProduct<
      cutlass::gemm::GemmShape<
          Shape::kM / Policy::WarpShape::kRow,    // number of output pixels processed per thread
          Shape::kN / Policy::WarpShape::kColumn, // number of channels processed per thread
          1>,
      ElementA,
      ElementB,
      ElementC,
      arch::OpMultiplyAdd,
      dp4a_type
      >;

  /// Underlying matrix multiply operator (concept: arch::Mma)
  using ArchMmaOperator = typename ThreadMma::ArchMmaOperator;

  /// Indicates math operator
  using MathOperator = typename ArchMmaOperator::Operator;

  /// Shape of the underlying instruction
  using InstructionShape = cutlass::gemm::GemmShape<1,1,use_dp4a ? 4 : 1>;

public:

  /// Iterates over the A operand in memory
  using IteratorA = cutlass::conv::warp::DepthwiseDirect2dConvSimtTileIterator<
      MatrixShape<Shape::kM, Shape::kN>, // <output tile=(P*Q), output channels> per warp
      FilterShape,
      ThreadOutputShape,
      ThreadBlockOutputShape,
      cutlass::gemm::Operand::kA,
      ElementA,
      Policy,
      IteratorAlgorithm,
      StrideShape,
      DilationShape,
      ActivationShape,
      PartitionsK,
      Shape::kK
      >;

  /// Storage for A tile
  using FragmentA = typename IteratorA::Fragment;

  /// Storage for transformed A tile
  using TransformedFragmentA = FragmentA;

  /// Iterates over the B operand in memory
  using IteratorB = cutlass::gemm::warp::MmaSimtTileIterator<
      MatrixShape<1, Shape::kN>,
      cutlass::gemm::Operand::kB,
      ElementB,
      LayoutB,
      Policy,
      PartitionsK,
      Shape::kK
      >;

  /// Storage for B tile
  using FragmentB = typename IteratorB::Fragment;

  /// Storage for transformed B tile
  using TransformedFragmentB = FragmentB;

  /// Iterates over the C operand in memory
  using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator<
      MatrixShape<Shape::kM, Shape::kN>,
      cutlass::gemm::Operand::kC,
      ElementC,
      LayoutC,
      Policy
      >;

  /// Storage for C tile
  using FragmentC = typename ThreadMma::FragmentC;

public:

  //
  // Methods
  //

  /// Ctor
  CUTLASS_DEVICE
  MmaDepthwiseDirectConvSimt() {}

  /// Performs a warp-level matrix multiply-accumulate operation
  CUTLASS_DEVICE
  void operator()(
      FragmentC &d,
      FragmentA a,
      FragmentB b,
      FragmentC const &c, int group_idx = 0) const {

    ThreadMma mma;

    mma(d, a, b, c);
  }

  /// Transform the mma operands to the required types
  CUTLASS_DEVICE
  void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B,
                 FragmentA const &A, FragmentB const &B) const {
    // TODO: Implement this
    dst_A = A;
    dst_B = B;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace conv
} // namespace cutlass
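The `use_dp4a` trait above gates the int8 fast path on a 4-element interleaved layout of A plus int8 operands, which is what lets `InstructionShape` widen its K extent from 1 to 4. A hedged standalone mirror of that compile-time selection (the layout tags and trait name are stand-ins for this sketch, not CUTLASS definitions):

#include <cstdint>
#include <type_traits>

// Stand-in layout tags for this sketch only.
struct RowMajorTag {};
template <int N> struct RowMajorInterleavedTag {};

// Mirrors the use_dp4a condition: interleaved<4> A layout and int8 A/B operands.
template <typename LayoutA, typename ElementA, typename ElementB>
struct UseDp4aSketch {
  static constexpr bool value =
      std::is_same<RowMajorInterleavedTag<4>, LayoutA>::value &&
      std::is_same<ElementA, int8_t>::value &&
      std::is_same<ElementB, int8_t>::value;
  // Mirrors InstructionShape = GemmShape<1, 1, use_dp4a ? 4 : 1>.
  static constexpr int kInstructionK = value ? 4 : 1;
};

static_assert(UseDp4aSketch<RowMajorInterleavedTag<4>, int8_t, int8_t>::kInstructionK == 4, "");
static_assert(UseDp4aSketch<RowMajorTag, float, float>::kInstructionK == 1, "");

int main() { return 0; }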
@ -40,6 +40,8 @@
#include "cutlass/tensor_ref.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/conv/convolution.h"

#include "cutlass/arch/memory_sm75.h"

#include "cutlass/layout/matrix.h"
@ -250,6 +252,611 @@ private:

///////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// Size of the matrix to load (concept: MatrixShape)
    typename Shape_,
    /// Size of filter (concept: gemm::GemmShape<Depth, Height, Width>)
    typename FilterShape_,
    /// Size of the matrix to load (concept: MatrixShape)
    typename ThreadOutputShape_,
    /// Size of the matrix to load (concept: MatrixShape)
    typename ThreadBlockOutputShape_,
    /// Operand identity
    cutlass::gemm::Operand Operand,
    /// Data type of A elements
    typename Element_,
    /// Shape of the warp in units of thread (concept: MmaSimtPolicy)
    typename Policy_,
    /// Iterator algo type
    conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
    /// Stride ( MatrixShape<Height, Width> )
    typename StrideShape = cutlass::MatrixShape<-1, -1>,
    /// Dilation ( MatrixShape<Height, Width> )
    typename DilationShape = cutlass::MatrixShape<-1, -1>,
    /// Activation Shape loaded by threadblock
    typename ActivationShape = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>,
    /// Number of partitions along K dimension - used in sliced-K
    int PartitionsK = 1,
    /// Group Size along kPartition - used in sliced-K
    int PartitionGroupSize = 1>
class DepthwiseDirect2dConvSimtTileIterator;

/// Specialization for A operands of row-major layouts
///
/// Concept: MutableRandomAccessContiguousTileIteratorConcept
///
template <
    /// Size of the matrix to load (concept: MatrixShape)
    typename Shape_,
    /// Size of filter (concept: gemm::GemmShape<Depth, Height, Width>)
    typename FilterShape_,
    /// Size of the matrix to load (concept: TensorNHWC)
    typename ThreadOutputShape_,
    /// Size of the matrix to load (concept: TensorNHWC)
    typename ThreadBlockOutputShape_,
    /// Data type of A elements
    typename Element_,
    /// Shape of the warp in units of thread (concept: MmaSimtPolicy)
    typename Policy_,
    /// Iterator algo type
    conv::IteratorAlgorithm IteratorAlgorithm,
    /// Stride ( MatrixShape<Height, Width> )
    typename StrideShape,
    /// Dilation ( MatrixShape<Height, Width> )
    typename DilationShape,
    /// Activation Shape loaded by threadblock
    typename ActivationShape,
    /// Number of partitions along K dimension - used in sliced-K
    int PartitionsK,
    /// Group Size along kPartition - used in sliced-K
    int PartitionGroupSize>
class DepthwiseDirect2dConvSimtTileIterator<Shape_,
                                            FilterShape_,
                                            ThreadOutputShape_,
                                            ThreadBlockOutputShape_,
                                            cutlass::gemm::Operand::kA,
                                            Element_,
                                            Policy_,
                                            IteratorAlgorithm,
                                            StrideShape,
                                            DilationShape,
                                            ActivationShape,
                                            PartitionsK,
                                            PartitionGroupSize> {
public:
  /// Shape of tile to load (concept: MatrixShape)
  using Shape = Shape_;

  /// Shape of filter (concept: gemm::GemmShape<Depth, Height, Width>)
  using FilterShape = FilterShape_;

  /// Shape of tile to load (concept: TensorNHWC)
  using ThreadOutputShape = ThreadOutputShape_;

  /// Shape of tile to load (concept: TensorNHWC)
  using ThreadBlockOutputShape = ThreadBlockOutputShape_;

  /// Operand tag
  static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kA;

  /// Element type
  using Element = Element_;

  /// Layout of policy
  using Layout = layout::RowMajor;

  /// Decomposition of elements among threads
  using Policy = Policy_;

  /// TensorRef type for loading element from a tensor
  using TensorRef = TensorRef<Element, Layout>;

  /// Index type
  using Index = typename TensorRef::Index;

  /// Long Index type
  using LongIndex = typename TensorRef::LongIndex;

  /// Coordinate for an element in the tensor
  using TensorCoord = typename TensorRef::TensorCoord;

  //
  // Derived quantities
  //

  static_assert(!(Shape::kRow % Policy::WarpShape::kRow),
                "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension.");

  static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
  static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
  static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero.");
  static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");

  // Thread-level shape of a fragment
  using ThreadShape = MatrixShape<
      ThreadOutputShape::kNHW, // Output tile shape computed by the current thread
      ThreadOutputShape::kC
  >;

  static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN),
                "Thread-level GEMM must be divisible by Policy::LaneMmaShape.");

  /// Number of individual loads
  using Iterations = MatrixShape<
      ThreadShape::kRow,
      ThreadShape::kColumn / Policy::LaneMmaShape::kN
  >;

  using ThreadTileCount = MatrixShape<
      ThreadBlockOutputShape::kH / ThreadOutputShape::kH,
      ThreadBlockOutputShape::kW / ThreadOutputShape::kW
  >;

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<Element, ThreadShape::kCount>;

protected:

  /// Internal reference
  cutlass::TensorRef<Array<Element, Policy::LaneMmaShape::kN>, layout::RowMajor> ref_;

  int activation_offset[ThreadOutputShape::kH][ThreadOutputShape::kW][Iterations::kColumn];
  int iterator_r_;
  int iterator_s_;
  int iterator_offset_;

  int inc_next_s_;
  int inc_next_r_;

  MatrixCoord lane_offset_;

public:

  /// Default ctor constructs null iterator
  CUTLASS_HOST_DEVICE
  DepthwiseDirect2dConvSimtTileIterator() { }

  /// Constructor from TensorRef
  CUTLASS_HOST_DEVICE
  DepthwiseDirect2dConvSimtTileIterator(
      TensorRef ref,
      int lane_id
  ) {

    // compute offset based on thread ID and lane layout
    typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();

    // Set channel offset
    lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN);

    ref.add_coord_offset(lane_offset_);

    ref_.reset(reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(ref.data()),
               ref.stride(0) / Policy::LaneMmaShape::kN);

    iterator_r_ = 0;
    iterator_s_ = 0;
    iterator_offset_ = 0;
  }

  /// Adds a pointer offset to internal pointer(s) to advance through memory
  CUTLASS_HOST_DEVICE
  DepthwiseDirect2dConvSimtTileIterator &add_pointer_offset(LongIndex offset) {
    ref_.add_pointer_offset(offset);
    return *this;
  }

  /// Sets up the initial state of the iterator from precomputed parameters.
  template<typename Params>
  CUTLASS_HOST_DEVICE
  void setup_initial_status(Params const& params) {

    inc_next_s_ = params.inc_next[0];
    inc_next_r_ = params.inc_next[1];

    // Get base HW offset of current threads
    int threadgroup = threadIdx.x / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC);
    int base_p_ =
        (threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH;
    int base_q_ =
        (threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW;

    CUTLASS_PRAGMA_UNROLL
    for (int p = 0; p < ThreadOutputShape::kH; ++p) {
      CUTLASS_PRAGMA_UNROLL
      for (int q = 0; q < ThreadOutputShape::kW; ++q) {
        CUTLASS_PRAGMA_UNROLL
        for (int col = 0; col < Iterations::kColumn; ++col) {
          int base_w = (base_q_ + q) * params.stride[0];
          int base_h = (base_p_ + p) * params.stride[1];

          int offset = base_h * params.activation_tile_w + base_w;
          activation_offset[p][q][col] = offset;
        }
      }
    }
  }
|
||||
|
||||
|
||||
/// Advances an iterator along logical dimensions of matrix in units of whole tiles
|
||||
CUTLASS_HOST_DEVICE
|
||||
DepthwiseDirect2dConvSimtTileIterator &add_tile_offset(TensorCoord const &coord) {
|
||||
// Set warp row and col start
|
||||
lane_offset_ = MatrixCoord({lane_offset_.row() + coord.row() * Shape::kRow, lane_offset_.column()});
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Advances an iterator along logical dimensions of matrix in units of whole tiles
|
||||
CUTLASS_HOST_DEVICE
|
||||
void advance(int32_t pointer_offset) {
|
||||
ref_.reset(ref_.data() + pointer_offset / sizeof(Element) / Policy::LaneMmaShape::kN);
|
||||
iterator_s_ = 0;
|
||||
iterator_r_ = 0;
|
||||
iterator_offset_ = 0;
|
||||
}
|
||||
|
||||
/// Advances the iterator along the advance dimension
|
||||
CUTLASS_HOST_DEVICE
|
||||
DepthwiseDirect2dConvSimtTileIterator &operator++() {
|
||||
++iterator_s_;
|
||||
if (iterator_s_ < FilterShape::kColumn) {
|
||||
iterator_offset_ += inc_next_s_;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
iterator_s_ = 0;
|
||||
|
||||
++iterator_r_;
|
||||
if (iterator_r_ < FilterShape::kRow) {
|
||||
iterator_offset_ += inc_next_r_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
iterator_r_ = 0;
|
||||
iterator_offset_ = 0;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Advances the iterator along the advance dimension
|
||||
CUTLASS_HOST_DEVICE
|
||||
DepthwiseDirect2dConvSimtTileIterator & operator--() {
|
||||
// Do nothing
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory at the location pointed to by the iterator. (vector loads)
|
||||
CUTLASS_HOST_DEVICE
|
||||
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {
|
||||
|
||||
Array<Element, Policy::LaneMmaShape::kN> *dst_ptr =
|
||||
reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(&frag);
|
||||
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int p = 0; p < ThreadOutputShape::kH; ++p) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int q = 0; q < ThreadOutputShape::kW; ++q) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < Iterations::kColumn; ++n) {
|
||||
void const *ptr = ref_.data() +
|
||||
ref_.offset({activation_offset[p][q][n] + (iterator_offset_),
|
||||
n * Policy::WarpShape::kColumn}) +
|
||||
pointer_offset / Policy::LaneMmaShape::kN;
|
||||
arch::shared_load(dst_ptr[n + q + p * ThreadOutputShape::kW], ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory at the location pointed to by the iterator.
|
||||
CUTLASS_HOST_DEVICE
|
||||
void load(Fragment &frag) const {
|
||||
load_with_pointer_offset(frag, 0);
|
||||
}
|
||||
|
||||
/// Stores a fragment to memory at the location pointed to by the iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {
|
||||
// Do nothing at present.
|
||||
}
|
||||
|
||||
/// Stores a fragment to memory at the location pointed to by the iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
void store(Fragment const &frag, Index pointer_offset) const {
|
||||
store_with_pointer_offset(frag, 0);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void set_kgroup_index(int k_group) {
|
||||
// no operation here
|
||||
}
|
||||
};
|
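For intuition (not part of this commit): operator++ above visits the filter taps with iterator_s_ (filter width) as the fastest-varying index and iterator_r_ (filter height) next, folding the precomputed inc_next_s_ / inc_next_r_ deltas into iterator_offset_. A standalone sketch of that traversal order, using a hypothetical 3x3 filter:

// Illustrative only: the r/s visit order implemented by operator++ above.
#include <cstdio>

int main() {
  int const kFilterR = 3, kFilterS = 3;   // hypothetical FilterShape::kRow/kColumn
  for (int r = 0; r < kFilterR; ++r) {    // iterator_r_ (slow)
    for (int s = 0; s < kFilterS; ++s) {  // iterator_s_ (fast)
      std::printf("visit filter tap (r=%d, s=%d)\n", r, s);
    }
  }
  return 0;
}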

///////////////////////////////////////////////////////////////////////////////////////////////////
/// Specialization for A operands of row-major layouts
///
/// Concept: MutableRandomAccessContiguousTileIteratorConcept
///
template <
    /// Size of the matrix to load (concept: MatrixShape)
    typename Shape_,
    /// Size of filter (concept: gemm::GemmShape<Depth, Height, Width>)
    typename FilterShape_,
    /// Size of the matrix to load (concept: TensorNHWC)
    typename ThreadOutputShape_,
    /// Size of the matrix to load (concept: TensorNHWC)
    typename ThreadBlockOutputShape_,
    /// Data type of A elements
    typename Element_,
    /// Shape of the warp in units of thread (concept: MmaSimtPolicy)
    typename Policy_,
    /// Stride ( MatrixShape<Height, Width> )
    typename StrideShape_,
    /// Dilation ( MatrixShape<Height, Width> )
    typename DilationShape_,
    /// Activation Shape loaded by threadblock
    typename ActivationShape_,
    /// Number of partitions along K dimension - used in sliced-K
    int PartitionsK,
    /// Group Size along kPartition - used in sliced-K
    int PartitionGroupSize>
class DepthwiseDirect2dConvSimtTileIterator<Shape_,
                                            FilterShape_,
                                            ThreadOutputShape_,
                                            ThreadBlockOutputShape_,
                                            cutlass::gemm::Operand::kA,
                                            Element_,
                                            Policy_,
                                            IteratorAlgorithm::kFixedStrideDilation,
                                            StrideShape_,
                                            DilationShape_,
                                            ActivationShape_,
                                            PartitionsK,
                                            PartitionGroupSize> {
 public:
  /// Shape of tile to load (concept: MatrixShape)
  using Shape = Shape_;

  /// Shape of filter (concept: gemm::GemmShape<Depth, Height, Width>)
  using FilterShape = FilterShape_;

  /// Shape of tile to load (concept: TensorNHWC)
  using ThreadOutputShape = ThreadOutputShape_;

  /// Shape of tile to load (concept: TensorNHWC)
  using ThreadBlockOutputShape = ThreadBlockOutputShape_;

  /// Stride ( MatrixShape<Height, Width> )
  using StrideShape = StrideShape_;

  /// Dilation ( MatrixShape<Height, Width> )
  using DilationShape = DilationShape_;

  /// Activation Shape loaded by threadblock
  using ActivationShape = ActivationShape_;

  /// Operand tag
  static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kA;

  /// Element type
  using Element = Element_;

  /// Layout of policy
  using Layout = layout::RowMajor;

  /// Decomposition of elements among threads
  using Policy = Policy_;

  /// TensorRef type for loading element from a tensor
  using TensorRef = TensorRef<Element, Layout>;

  /// Index type
  using Index = typename TensorRef::Index;

  /// Long Index type
  using LongIndex = typename TensorRef::LongIndex;

  /// Coordinate for an element in the tensor
  using TensorCoord = typename TensorRef::TensorCoord;

  //
  // Derived quantities
  //

  static_assert(!(Shape::kRow % Policy::WarpShape::kRow),
                "The warp-level GEMM M size must be divisible by the number of threads arranged "
                "along the M dimension.");

  static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
  static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
  static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero.");
  static_assert(Shape::kRow / Policy::WarpShape::kRow > 0,
                "Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");

  // Activations loaded by threadblock
  static int const ThreadActivationShapeH = (ThreadOutputShape::kH - 1) * StrideShape::kRow +
                                            (FilterShape::kRow - 1) * DilationShape::kRow + 1;

  static int const ThreadActivationShapeW = (ThreadOutputShape::kW - 1) * StrideShape::kColumn +
                                            (FilterShape::kColumn - 1) * DilationShape::kColumn + 1;

  using ThreadActivationShape = cutlass::conv::
      TensorNHWCShape<1, ThreadActivationShapeH, ThreadActivationShapeW, ThreadOutputShape::kC>;

  // Thread-level shape of a fragment
  using ThreadShape =
      MatrixShape<ThreadOutputShape::kNHW,
                  ThreadOutputShape::kC>;

  static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN),
                "Thread-level GEMM must be divisible by Policy::LaneMmaShape.");

  /// Number of individual loads
  using Iterations =
      MatrixShape<ThreadShape::kRow, ThreadShape::kColumn / Policy::LaneMmaShape::kN>;

  using ThreadTileCount = MatrixShape<ThreadBlockOutputShape::kH / ThreadOutputShape::kH,
                                      ThreadBlockOutputShape::kW / ThreadOutputShape::kW>;

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<Element, ThreadShape::kCount>;

 protected:
  /// Internal reference
  cutlass::TensorRef<Array<Element, Policy::LaneMmaShape::kN>, layout::RowMajor> ref_;

  Array<Element, Policy::LaneMmaShape::kN>
      activation[ThreadActivationShape::kH][ThreadActivationShape::kW][Iterations::kColumn];
  int iterator_r_;
  int iterator_s_;

  MatrixCoord lane_offset_;

 public:
  /// Default ctor constructs null iterator
  CUTLASS_HOST_DEVICE
  DepthwiseDirect2dConvSimtTileIterator() {}

  /// Constructor from TensorRef
  CUTLASS_HOST_DEVICE
  DepthwiseDirect2dConvSimtTileIterator(TensorRef ref, int lane_id) {
    // compute offset based on thread ID and lane layout
    typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();

    // Set channel offset
    lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN);

    ref.add_coord_offset(lane_offset_);

    ref_.reset(reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(ref.data()),
               ref.stride(0) / Policy::LaneMmaShape::kN);

    iterator_r_ = 0;
    iterator_s_ = 0;
  }

  /// Adds a pointer offset to internal pointer(s) to advance through memory
  CUTLASS_HOST_DEVICE
  DepthwiseDirect2dConvSimtTileIterator &add_pointer_offset(LongIndex offset) {
    ref_.add_pointer_offset(offset);
    return *this;
  }

  /// Preloads this thread's activation patch from shared memory into registers.
  template <typename Params>
  CUTLASS_HOST_DEVICE void setup_initial_status(Params const &params) {

    // Get base HW offset of current threads
    int threadgroup = threadIdx.x / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC);
    int base_h =
        (threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH * StrideShape::kRow;
    int base_w =
        (threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW * StrideShape::kColumn;

    CUTLASS_PRAGMA_UNROLL
    for (int h = 0; h < ThreadActivationShape::kH; ++h) {
      CUTLASS_PRAGMA_UNROLL
      for (int w = 0; w < ThreadActivationShape::kW; ++w) {
        CUTLASS_PRAGMA_UNROLL
        for (int col = 0; col < Iterations::kColumn; ++col) {
          int offset = (base_h + h) * ActivationShape::kW + (base_w + w);

          void const *ptr = ref_.data() + ref_.offset({offset, col * Policy::WarpShape::kColumn});
          arch::shared_load(activation[h][w][col], ptr);
        }
      }
    }
  }

  /// Advances an iterator along logical dimensions of matrix in units of whole tiles
  CUTLASS_HOST_DEVICE
  DepthwiseDirect2dConvSimtTileIterator &add_tile_offset(TensorCoord const &coord) {
    // Set warp row and col start
    lane_offset_ =
        MatrixCoord({lane_offset_.row() + coord.row() * Shape::kRow, lane_offset_.column()});
    return *this;
  }

  /// Advances the internal pointer and resets the filter iteration state.
  CUTLASS_HOST_DEVICE
  void advance(int32_t pointer_offset) {
    ref_.reset(ref_.data() + pointer_offset / sizeof(Element) / Policy::LaneMmaShape::kN);
    iterator_s_ = 0;
    iterator_r_ = 0;
  }

  /// Advances the iterator along the advance dimension
  CUTLASS_HOST_DEVICE
  DepthwiseDirect2dConvSimtTileIterator &operator++() {
    ++iterator_s_;
    if (iterator_s_ < FilterShape::kColumn) {
      return *this;
    }

    iterator_s_ = 0;

    ++iterator_r_;
    if (iterator_r_ < FilterShape::kRow) {
      return *this;
    }

    iterator_r_ = 0;
    return *this;
  }

  /// Decrement is not supported; provided to satisfy the iterator concept.
  CUTLASS_HOST_DEVICE
  DepthwiseDirect2dConvSimtTileIterator &operator--() {
    // Do nothing
    return *this;
  }

  /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads)
  CUTLASS_HOST_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {
    Array<Element, Policy::LaneMmaShape::kN> *dst_ptr =
        reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int p = 0; p < ThreadOutputShape::kH; ++p) {
      CUTLASS_PRAGMA_UNROLL
      for (int q = 0; q < ThreadOutputShape::kW; ++q) {
        CUTLASS_PRAGMA_UNROLL
        for (int n = 0; n < Iterations::kColumn; ++n) {
          const int h = p * StrideShape::kRow + iterator_r_ * DilationShape::kRow;
          const int w = q * StrideShape::kColumn + iterator_s_ * DilationShape::kColumn;

          dst_ptr[n + q + p * ThreadOutputShape::kW] = activation[h][w][n];
        }
      }
    }
  }

  /// Loads a fragment from memory at the location pointed to by the iterator.
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const { load_with_pointer_offset(frag, 0); }

  /// Stores a fragment to memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {
    // Do nothing at present.
  }

  /// Stores a fragment to memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void store(Fragment const &frag, Index pointer_offset) const {
    store_with_pointer_offset(frag, 0);
  }

  CUTLASS_DEVICE
  void set_kgroup_index(int k_group) {
    // no operation here
  }
};

} // namespace warp
} // namespace gemm
} // namespace conv
} // namespace cutlass

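A worked instance of the ThreadActivationShape arithmetic above (the concrete sizes are illustrative, not from this commit): with a 2x2 per-thread output tile, a 3x3 filter, stride 1, and dilation 1, each thread needs (2 - 1) * 1 + (3 - 1) * 1 + 1 = 4 activations per spatial dimension.

// Compile-time sketch of the activation-footprint formula; values are examples.
constexpr int kThreadOutH = 2, kFilterR = 3, kStrideH = 1, kDilationH = 1;
constexpr int kActivationH =
    (kThreadOutH - 1) * kStrideH + (kFilterR - 1) * kDilationH + 1;
static_assert(kActivationH == 4, "a 2-row output with a 3-tap filter reads 4 activation rows");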
@ -100,12 +100,21 @@ public:
    }
  }

  /// Constructs from some other Coord
  template <int R, typename I, typename L>
  CUTLASS_HOST_DEVICE
  Coord(Coord<R, I, L> other) {
    for (int i = 0; i < kRank; ++i) {
      idx[i] = other[i];
    }
  }

  /// Returns a slice of the Coord which may be larger or smaller in rank
  /// than this.
  template <int Slice>
  CUTLASS_HOST_DEVICE
  Coord<Slice> slice(int start = 0, Index identity = 0) const {
    Coord<Slice> result;
  Coord<Slice, Index, LongIndex> slice(int start = 0, Index identity = 0) const {
    Coord<Slice, Index, LongIndex> result;
    for (int i = 0; i < Slice; ++i) {
      if (i + start < kRank) {
        result[i] = idx[i + start];

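A usage sketch for the converting constructor and slice() in this hunk (illustrative, assuming the cutlass::Coord API as declared above): slicing can shrink the rank, and when it grows the rank, components beyond the source are filled with the identity argument.

// Illustrative only.
cutlass::Coord<4> nhwc = cutlass::make_Coord(2, 8, 8, 64);
cutlass::Coord<2> hw = nhwc.slice<2>(1);        // {8, 8}: ranks 1..2 of nhwc
cutlass::Coord<5> grown = nhwc.slice<5>(0, 1);  // {2, 8, 8, 64, 1}: padded with identity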
@ -59,7 +59,9 @@ inline std::ostream &operator<<(std::ostream &out, dim3 d) {

/// Output operator for CUDA built-in error type
inline std::ostream &operator<<(std::ostream &out, cudaError_t error) {
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
  return out << cudaGetErrorString(error);
#endif
}

///////////////////////////////////////////////////////////////////////////////////////////////////
@ -252,8 +254,9 @@ namespace conv {
inline
std::ostream& operator<<(std::ostream& out, Conv2dProblemSize const& problem) {
  out << "NHWC: (" << problem.N << ", " << problem.H << ", " << problem.W << ", " << problem.C << ")" << std::endl
      << "KRSC: (" << problem.K << ", " << problem.R << ", " << problem.S << ", " << problem.C << ")" << std::endl
      << "KRSC: (" << problem.K << ", " << problem.R << ", " << problem.S << ", " << problem.C / problem.groups << ")" << std::endl
      << "NPQK: (" << problem.N << ", " << problem.P << ", " << problem.Q << ", " << problem.K << ")" << std::endl
      << "groups: (" << problem.groups << ")" << std::endl
      << "Pad_h, Pad_w: (" << problem.pad_h << ", " << problem.pad_w << ")" << std::endl
      << "Stride_h, Stride_w: (" << problem.stride_h << ", " << problem.stride_w << ")" << std::endl
      << "Dilation_h, Dilation_w: (" << problem.dilation_h << ", " << problem.dilation_w << ")" << std::endl

@ -57,6 +57,23 @@ void Kernel(typename Operator::Params params) {
  op(params, *shared_storage);
}


/// Generic CUTLASS kernel template.
template <typename Operator>
__global__
void Kernel2(typename Operator::Params params) {
  // Dynamic shared memory base pointer
  extern __shared__ int SharedStorageBase[];

  // Declare pointer to dynamic shared memory.
  typename Operator::SharedStorage *shared_storage =
      reinterpret_cast<typename Operator::SharedStorage *>(SharedStorageBase);

  Operator::invoke(params, *shared_storage);

}


////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

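Because Kernel2 places its SharedStorage in dynamically allocated shared memory, a launch has to pass the storage size explicitly, and sizes above the 48 KB static limit require an opt-in attribute. A hedged launch sketch (MyOperator, grid, block, and params are hypothetical stand-ins):

// Illustrative only: launching Kernel2 with dynamic shared memory.
int smem_size = int(sizeof(typename MyOperator::SharedStorage));
if (smem_size >= (48 << 10)) {
  cudaFuncSetAttribute(cutlass::Kernel2<MyOperator>,
                       cudaFuncAttributeMaxDynamicSharedMemorySize,
                       smem_size);
}
cutlass::Kernel2<MyOperator><<<grid, block, smem_size>>>(params);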
@ -104,15 +104,15 @@ struct Identity {
template <typename T, int N>
struct Identity<Array<T, N> > {
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &rhs) const {
    return rhs;
  Array<T, N> operator()(Array<T, N> const &value) const {
    return value;
  }

  using Params = LinearCombinationGenericParams<T>;

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
    return this->operator()(rhs);
  Array<T, N> operator()(Array<T, N> const &value, Params const &params_) const {
    return this->operator()(value);
  }
};

@ -183,7 +183,7 @@ struct LeakyReLU {
    Params():
      LinearCombinationGenericParams<T>(),
      leaky_alpha(T(1)) {}

    CUTLASS_HOST_DEVICE
    Params(
      T alpha,
@ -228,21 +228,21 @@ struct LeakyReLU<Array<T, N> > {

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &rhs, T const & alpha_recip) const {
  Array<T, N> operator()(Array<T, N> const &value, T const & alpha_recip) const {
    Array<T, N> y;
    LeakyReLU<T> leaky_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < int(rhs.size()); ++i) {
      y[i] = leaky_op(rhs[i], alpha_recip);
    for (int i = 0; i < int(value.size()); ++i) {
      y[i] = leaky_op(value[i], alpha_recip);
    }

    return y;
  }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
    return this->operator()(rhs, params_.leaky_alpha);
  Array<T, N> operator()(Array<T, N> const &value, Params const &params_) const {
    return this->operator()(value, params_.leaky_alpha);
  }
};

@ -265,13 +265,13 @@ struct Tanh {
template <typename T, int N>
struct Tanh<Array<T, N> > {
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &rhs) const {
  Array<T, N> operator()(Array<T, N> const &value) const {
    Array<T, N> y;
    Tanh<T> tanh_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      y[i] = tanh_op(rhs[i]);
      y[i] = tanh_op(value[i]);
    }

    return y;
@ -280,8 +280,8 @@ struct Tanh<Array<T, N> > {
  using Params = LinearCombinationGenericParams<T>;

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
    return this->operator()(rhs);
  Array<T, N> operator()(Array<T, N> const &value, Params const &params_) const {
    return this->operator()(value);
  }
};

@ -299,8 +299,8 @@ struct Tanh<Array<half_t, N>> {
  using Params = LinearCombinationGenericParams<T>;

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
    return this->operator()(rhs);
  Array<T, N> operator()(Array<T, N> const &value, Params const &params_) const {
    return this->operator()(value);
  }
};

@ -323,13 +323,13 @@ struct Sigmoid {
template <typename T, int N>
struct Sigmoid<Array<T, N> > {
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &rhs) const {
  Array<T, N> operator()(Array<T, N> const &value) const {
    Array<T, N> y;
    Sigmoid<T> sigmoid_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      y[i] = sigmoid_op(rhs[i]);
      y[i] = sigmoid_op(value[i]);
    }

    return y;
@ -338,8 +338,8 @@ struct Sigmoid<Array<T, N> > {
  using Params = LinearCombinationGenericParams<T>;

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
    return this->operator()(rhs);
  Array<T, N> operator()(Array<T, N> const &value, Params const &params_) const {
    return this->operator()(value);
  }
};

@ -398,17 +398,17 @@ struct SiLu {
template <typename T, int N>
struct SiLu<Array<T, N>> {
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &rhs) const {
  Array<T, N> operator()(Array<T, N> const &value) const {
    Sigmoid<Array<T, N>> sigmoid_op;
    multiplies<Array<T, N>> mul;
    return mul(rhs, sigmoid_op(rhs));
    return mul(value, sigmoid_op(value));
  }

  using Params = LinearCombinationGenericParams<T>;

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
    return this->operator()(rhs);
  Array<T, N> operator()(Array<T, N> const &value, Params const &params_) const {
    return this->operator()(value);
  }
};

@ -458,13 +458,13 @@ struct HardSwish<float> {
template <typename T, int N>
struct HardSwish<Array<T, N> > {
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &rhs) const {
  Array<T, N> operator()(Array<T, N> const &value) const {
    Array<T, N> y;
    HardSwish<T> hardswish_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      y[i] = hardswish_op(rhs[i]);
      y[i] = hardswish_op(value[i]);
    }

    return y;
@ -483,13 +483,13 @@ struct HardSwish<Array<half_t, N> > {
  using T = half_t;

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &rhs) const {
  Array<T, N> operator()(Array<T, N> const &value) const {
    minimum<Array<T, N> > mn;
    maximum<Array<T, N> > mx;
    multiplies<Array<T, N> > mul;
    plus<Array<T, N> > add;

    return mul(mul(mn(mx(add(rhs, T(3)), T(0)), T(6)), rhs), T(0.16666667f));
    return mul(mul(mn(mx(add(value, T(3)), T(0)), T(6)), value), T(0.16666667f));
  }

  using Params = LinearCombinationGenericParams<T>;
@ -561,13 +561,13 @@ struct GELU<double> {
template <typename T, int N>
struct GELU<Array<T, N> > {
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &rhs) const {
  Array<T, N> operator()(Array<T, N> const &value) const {
    Array<T, N> y;
    GELU<T> gelu_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      y[i] = gelu_op(rhs[i]);
      y[i] = gelu_op(value[i]);
    }

    return y;
@ -576,8 +576,8 @@ struct GELU<Array<T, N> > {
  using Params = LinearCombinationGenericParams<T>;

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
    return this->operator()(rhs);
  Array<T, N> operator()(Array<T, N> const &value, Params const &params_) const {
    return this->operator()(value);
  }
};

@ -601,7 +601,6 @@ struct GELU_taylor {
  T operator()(T const &scalar, Params const &params_) const {
    return this->operator()(scalar);
  }

};

template <int N>
@ -632,8 +631,8 @@ struct GELU_taylor<Array<half_t, N> > {
  using Params = LinearCombinationGenericParams<half_t>;

  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const &rhs, Params const &params_) const {
    return this->operator()(rhs);
  Array<half_t, N> operator()(Array<half_t, N> const &value, Params const &params_) const {
    return this->operator()(value);
  }
};

@ -641,13 +640,13 @@ template <typename T, int N>
struct GELU_taylor<Array<T, N> > {
  static const bool kIsHeavy = true;
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &rhs) const {
  Array<T, N> operator()(Array<T, N> const &value) const {
    Array<T, N> y;
    GELU_taylor<T> gelu_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      y[i] = gelu_op(rhs[i]);
      y[i] = gelu_op(value[i]);
    }

    return y;
@ -656,8 +655,8 @@ struct GELU_taylor<Array<T, N> > {
  using Params = LinearCombinationGenericParams<T>;

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
    return this->operator()(rhs);
  Array<T, N> operator()(Array<T, N> const &value, Params const &params_) const {
    return this->operator()(value);
  }
};

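These hunks rename the Array-specialization parameter from rhs to value without changing behavior: each functor still applies its scalar op lane by lane. A small usage sketch (types and values illustrative):

// Illustrative only: apply the vectorized Sigmoid functor to a 4-wide fragment.
cutlass::Array<float, 4> frag;
frag.fill(0.5f);
cutlass::epilogue::thread::Sigmoid<cutlass::Array<float, 4>> sigmoid_op;
cutlass::Array<float, 4> activated = sigmoid_op(frag);  // elementwise 1 / (1 + exp(-x))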
@ -78,6 +78,9 @@ public:
  using ElementwiseOp = ElementwiseOp_;
  using BinaryOp = BinaryOp_;

  // Indicates that this epilogue applies only one binary operation
  static bool const kIsSingleSource = true;

  using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
  using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
  using FragmentC = Array<ElementOutput, kElementsPerAccess>;

@ -223,6 +223,9 @@ public:
  using ElementwiseOp = ReLu<ElementCompute>;
  using BinaryOp = plus<ElementCompute>;

  // Indicates that this epilogue applies only one binary operation
  static bool const kIsSingleSource = true;

  using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
  using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
  using FragmentC = Array<ElementOutput, kElementsPerAccess>;

@ -37,6 +37,7 @@
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/scale_type.h"

/////////////////////////////////////////////////////////////////////////////////////////////////


@ -30,7 +30,7 @@
 **************************************************************************************************/

/*! \file
  \brief Epilogue functor specialized for residual blocks in deep neural network.
  \brief Epilogue functor specialized for residual blocks in deep neural networks.
*/

#pragma once
@ -45,14 +45,24 @@ namespace cutlass {
namespace epilogue {
namespace thread {

// /// Models a residual block of the form: UnaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual))
namespace detail {

/// Dummy class used to designate that the second binary operator in the epilogue is unused
template <typename T>
class NoOp {};

}

/// Models a residual block of the form: UnaryOp(BinaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual1), residual2))
template <typename ElementOutput_, typename ElementAccumulator_,
          typename ElementCompute_, typename ElementC_, int ElementsPerAccess,
          template <typename T> class ActivationOp_,
          template <typename T> class BinaryOp_,
          template <typename T> class UnaryOp_>
          template <typename T> class BinaryOp1_,
          template <typename T> class UnaryOp_,
          template <typename T> class BinaryOp2_ = detail::NoOp>
class LinearCombinationResidualBlock {
public:
  static bool const kIsSingleSource = false;

  using ElementOutput = ElementC_;
  using ElementC = ElementC_;
@ -62,7 +72,130 @@ public:
  static int const kCount = kElementsPerAccess;

  using UnaryOp = UnaryOp_<Array<ElementCompute, kCount>>;
  using BinaryOp = BinaryOp_<Array<ElementCompute, kCount>>;
  using BinaryOp1 = BinaryOp1_<Array<ElementCompute, kCount>>;
  using BinaryOp2 = BinaryOp2_<Array<ElementCompute, kCount>>;
  using ActivationOp = ActivationOp_<Array<ElementCompute, kCount>>;

  using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
  using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
  using FragmentC = Array<ElementC, kElementsPerAccess>;
  using FragmentOutput = Array<ElementOutput, kElementsPerAccess>;

  using ElementZ = ElementOutput_;
  using ElementT = ElementZ;
  using FragmentZ = Array<ElementZ, kElementsPerAccess>;
  using FragmentT = Array<ElementT, kElementsPerAccess>;

  static bool const kIsHeavy = true;
  static bool const kStoreZ = true;
  static bool const kStoreT = false;

  /// Host-constructable parameters structure
  struct Params {

    ElementCompute alpha;                     ///< scales accumulators
    ElementCompute beta;                      ///< scales residual input
    ElementCompute const *alpha_ptr{nullptr}; ///< pointer to accumulator scalar - if not null, loads it from memory
    ElementCompute const *beta_ptr{nullptr};  ///< pointer to residual scalar - if not null, loads it from memory

    CUTLASS_HOST_DEVICE
    Params() : alpha(ElementCompute(1)), beta(ElementCompute(1)) {}

    CUTLASS_HOST_DEVICE
    Params(ElementCompute alpha, ElementCompute beta)
        : alpha(alpha), beta(beta) {}

    CUTLASS_HOST_DEVICE
    Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
        : alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {}
  };

private:

  ElementCompute alpha_;
  ElementCompute beta_;
  bool skip_elementwise_;

public:

  /// Constructor from Params
  CUTLASS_HOST_DEVICE
  LinearCombinationResidualBlock(Params const &params) {
    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
    skip_elementwise_ = false;
  }

  /// The "source" tensor corresponds to the residual input
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const { return true; }

  /// Functionally required for serial reduction in the epilogue
  /// IMPORTANT: Split-k is supported only when ActivationOp is Identity.
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {
    if (k_partition) {
      beta_ = ElementCompute(1);
    }

    if (k_partition != k_partition_count - 1) {
      skip_elementwise_ = true;
    }
  }

  /// Applies the operation UnaryOp(BinaryOp(BinaryOp(ActivationOp(AB + bias), residual1), residual2))
  CUTLASS_HOST_DEVICE
  void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB,
                  FragmentC const &residual1, FragmentC const &residual2,
                  FragmentCompute const &bias) const {
    UnaryOp unary_op;
    BinaryOp1 binary_op1;
    BinaryOp2 binary_op2;
    ActivationOp activation;

    FragmentCompute tmp_Accum =
        NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
    FragmentCompute tmp_residual1 =
        NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual1);
    FragmentCompute tmp_residual2 =
        NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual2);

    FragmentCompute z =
        binary_op2(binary_op1(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual1), beta_ * tmp_residual2);
    FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z);

    NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> convert_z;
    frag_Z = convert_z(result_Z);
  }

  /// Should never be called
  CUTLASS_HOST_DEVICE
  void operator()(FragmentOutput &, FragmentOutput &, FragmentAccumulator const &,
                  FragmentCompute const &) const {}
};

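In scalar terms, the two-source operator above computes z = BinaryOp2(BinaryOp1(ActivationOp(alpha * AB + bias), beta * residual1), beta * residual2) per element, then applies UnaryOp unless this is a non-final split-k partition. A scalar reference sketch; the choice of ReLU / plus / identity for the ops is illustrative, not fixed by this file:

// Illustrative scalar model of the two-source residual epilogue.
inline float residual_block_ref(float ab, float bias, float res1, float res2,
                                float alpha, float beta) {
  float t = alpha * ab + bias;
  float activated = t > 0.0f ? t : 0.0f;              // ActivationOp = ReLU
  float z = (activated + beta * res1) + beta * res2;  // BinaryOp1 = BinaryOp2 = plus
  return z;                                           // UnaryOp = Identity
}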
/// Models a residual block of the form: UnaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual))
template <typename ElementOutput_, typename ElementAccumulator_,
          typename ElementCompute_, typename ElementC_, int ElementsPerAccess,
          template <typename T> class ActivationOp_,
          template <typename T> class BinaryOp1_,
          template <typename T> class UnaryOp_>
class LinearCombinationResidualBlock<ElementOutput_, ElementAccumulator_,
                                     ElementCompute_, ElementC_, ElementsPerAccess,
                                     ActivationOp_, BinaryOp1_, UnaryOp_,
                                     detail::NoOp> {
public:
  static bool const kIsSingleSource = true;

  using ElementOutput = ElementC_;
  using ElementC = ElementC_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kCount = kElementsPerAccess;

  using UnaryOp = UnaryOp_<Array<ElementCompute, kCount>>;
  using BinaryOp = BinaryOp1_<Array<ElementCompute, kCount>>;
  using ActivationOp = ActivationOp_<Array<ElementCompute, kCount>>;

  using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;

@ -1,197 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
  \brief Epilogue functor specialized for residual blocks in deep neural network.
*/

#pragma once

#include "cutlass/array.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace thread {

// /// Models a residual block of the form: UnaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual))
// or form UnaryOp(BinaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual1), residual2))
template <typename ElementOutput_, typename ElementAccumulator_,
          typename ElementCompute_, typename ElementC_, int ElementsPerAccess,
          template <typename T> class ActivationOp_,
          template <typename T> class BinaryOp1_,
          template <typename T> class UnaryOp_,
          template <typename T> class BinaryOp2_ = BinaryOp1_>
class LinearCombinationResidualBlockV2 {
public:

  using ElementOutput = ElementC_;
  using ElementC = ElementC_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kCount = kElementsPerAccess;

  using UnaryOp = UnaryOp_<Array<ElementCompute, kCount>>;
  using BinaryOp1 = BinaryOp1_<Array<ElementCompute, kCount>>;
  using BinaryOp2 = BinaryOp2_<Array<ElementCompute, kCount>>;
  using ActivationOp = ActivationOp_<Array<ElementCompute, kCount>>;

  using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
  using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
  using FragmentC = Array<ElementC, kElementsPerAccess>;
  using FragmentOutput = Array<ElementOutput, kElementsPerAccess>;

  using ElementZ = ElementOutput_;
  using ElementT = ElementZ;
  using FragmentZ = Array<ElementZ, kElementsPerAccess>;
  using FragmentT = Array<ElementT, kElementsPerAccess>;

  static bool const kIsHeavy = true;
  static bool const kStoreZ = true;
  static bool const kStoreT = false;

  /// Host-constructable parameters structure
  struct Params {

    ElementCompute alpha;                     ///< scales accumulators
    ElementCompute beta;                      ///< scales residual input
    ElementCompute const *alpha_ptr{nullptr}; ///< pointer to accumulator scalar - if not null, loads it from memory
    ElementCompute const *beta_ptr{nullptr};  ///< pointer to residual scalar - if not null, loads it from memory

    CUTLASS_HOST_DEVICE
    Params() : alpha(ElementCompute(1)), beta(ElementCompute(1)) {}

    CUTLASS_HOST_DEVICE
    Params(ElementCompute alpha, ElementCompute beta)
        : alpha(alpha), beta(beta) {}

    CUTLASS_HOST_DEVICE
    Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
        : alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {}
  };

private:

  ElementCompute alpha_;
  ElementCompute beta_;
  bool skip_elementwise_;

public:

  /// Constructor from Params
  CUTLASS_HOST_DEVICE
  LinearCombinationResidualBlockV2(Params const &params) {
    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
    skip_elementwise_ = false;
  }

  /// The "source" tensor corresponds to the residual input
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const { return true; }

  /// Functionally required for serial reduction in the epilogue
  /// IMPORTANT: Split-k is supported only when ActivationOp is Identity.
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {
    if (k_partition) {
      beta_ = ElementCompute(1);
    }

    if (k_partition != k_partition_count - 1) {
      skip_elementwise_ = true;
    }
  }

  /// Applies the operation UnaryOp(BinaryOp(ActivationOp(AB + bias), residual))
  CUTLASS_HOST_DEVICE
  void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB,
                  FragmentC const &residual,
                  FragmentCompute const &bias) const {
    UnaryOp unary_op;
    BinaryOp1 binary_op;
    ActivationOp activation;

    FragmentCompute tmp_Accum =
        NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
    FragmentCompute tmp_residual =
        NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual);

    FragmentCompute z =
        binary_op(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual);
    FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z);

    NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> convert_z;
    frag_Z = convert_z(result_Z);
  }

  /// Applies the operation UnaryOp(BinaryOp(BinaryOp(ActivationOp(AB + bias), residual1), residual2))
  CUTLASS_HOST_DEVICE
  void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB,
                  FragmentC const &residual1, FragmentC const &residual2,
                  FragmentCompute const &bias) const {
    UnaryOp unary_op;
    BinaryOp1 binary_op1;
    BinaryOp2 binary_op2;
    ActivationOp activation;

    FragmentCompute tmp_Accum =
        NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
    FragmentCompute tmp_residual1 =
        NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual1);
    FragmentCompute tmp_residual2 =
        NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(residual2);

    FragmentCompute z =
        binary_op2(binary_op1(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual1), beta_ * tmp_residual2);
    FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z);

    NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> convert_z;
    frag_Z = convert_z(result_Z);
  }

  /// Should never be called
  CUTLASS_HOST_DEVICE
  void operator()(FragmentOutput &, FragmentOutput &, FragmentAccumulator const &,
                  FragmentCompute const &) const {}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace thread
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -58,12 +58,16 @@
#include "cutlass/epilogue/warp/fragment_iterator_simt.h"
#include "cutlass/epilogue/warp/tile_iterator_simt.h"
#include "cutlass/epilogue/threadblock/default_thread_map_simt.h"
#include "cutlass/transform/pitch_linear_thread_map.h"

#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator_pitch_liner.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/epilogue_depthwise.h"

#include "cutlass/layout/permute.h"

@ -314,6 +318,100 @@ struct DefaultEpilogueSimtAffineRankN {
};

/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for epilogues for SimtOps.
template <typename Shape_,                   // ThreadBlock Shape
          typename WarpMmaSimt_,             // mma_depthwise_simt
          typename OutputOp_,
          int ElementsPerAccess_,
          typename ThreadOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1>,
          typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1> >
struct DefaultDirectConvEpilogueSimt {
  using Shape = Shape_;
  using WarpMmaSimt = WarpMmaSimt_;
  using WarpShape = typename WarpMmaSimt::Shape;
  using OutputOp = OutputOp_;
  using ThreadOutputShape = ThreadOutputShape_;
  using ThreadBlockOutputShape = ThreadBlockOutputShape_;
  static int const kElementsPerAccess = ElementsPerAccess_;

  using ElementOutput = typename OutputOp::ElementOutput;
  using LayoutC = typename WarpMmaSimt::LayoutC;
  using ElementAccumulator = typename WarpMmaSimt::ElementC;

  /// Number of warps total
  using WarpCount = gemm::GemmShape<
    Shape::kM / WarpShape::kM,
    Shape::kN / WarpShape::kN
  >;

  static int const kWarpSize = cutlass::gemm::warp::WarpSize<arch::OpClassSimt>::value;

  static int const kThreads = WarpCount::kCount * kWarpSize;

  //
  // Thread map
  //

  using OutputTileThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<
    layout::PitchLinearShape<ThreadBlockOutputShape::kC, ThreadBlockOutputShape::kNHW>,
    kThreads,
    kElementsPerAccess
  >;

  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorDirectConv<
    OutputTileThreadMap,
    ElementOutput,
    ThreadOutputShape,
    ThreadBlockOutputShape
  >;

  using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt<
    typename WarpMmaSimt::Shape,
    typename WarpMmaSimt::ThreadMma,
    layout::RowMajor,
    typename WarpMmaSimt::Policy
  >;

  using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimtDirect2dConv<
    typename WarpMmaSimt::Shape,
    ThreadOutputShape,
    ThreadBlockOutputShape,
    typename WarpMmaSimt::ThreadMma,
    ElementAccumulator,
    layout::RowMajor,
    typename WarpMmaSimt::Policy
  >;

  using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorPitchLiner<
    OutputTileThreadMap,
    ElementAccumulator
  >;

  /// Hard-coded padding elements added
  using Padding = typename WarpTileIterator::Padding;

  //
  // Define the epilogue
  //
  using Epilogue = cutlass::epilogue::threadblock::EpilogueDepthwise<
    Shape,
    ThreadOutputShape,
    ThreadBlockOutputShape,
    WarpMmaSimt,
    OutputTileIterator,
    AccumulatorFragmentIterator,
    WarpTileIterator,
    SharedLoadIterator,
    OutputOp,
    Padding
  >;
};

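A worked instance of the warp/thread arithmetic above (shapes illustrative): a 64x64 threadblock tile with 32x32 warp tiles yields a 2x2 WarpCount, so WarpCount::kCount = 4 and, with the 32-thread SIMT warp size, kThreads = 128.

// Compile-time sketch of the WarpCount / kThreads computation; sizes are examples.
constexpr int kTbM = 64, kTbN = 64, kWarpM = 32, kWarpN = 32, kWarpSize = 32;
constexpr int kWarpCount = (kTbM / kWarpM) * (kTbN / kWarpN);  // WarpCount::kCount
static_assert(kWarpCount * kWarpSize == 128, "2x2 warps of 32 threads -> 128 threads");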
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
@ -293,6 +293,141 @@ struct DefaultIteratorsTensorOp<
|
||||
|
||||
static int const kFragmentsPerIteration = 1;
|
||||
};
|
||||
|
||||
/// Partial specialization for float_e4m3_t <= float x 16/8 epilogues avoids shared memory bank conflicts.
|
||||
/// Threadblock::kN = 256 still has bank conflicts.
|
||||
template <
|
||||
int ElementsPerAccess,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename ThreadMap
|
||||
>
|
||||
struct DefaultIteratorsTensorOp<
|
||||
cutlass::float_e4m3_t,
|
||||
float,
|
||||
ElementsPerAccess,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
ThreadMap> {
|
||||
|
||||
using ElementOutput = cutlass::float_e4m3_t;
|
||||
|
||||
static_assert((ElementsPerAccess == 16 || ElementsPerAccess == 8),
|
||||
"ElementsPerAccess needs to be 16 or 8.");
|
||||
|
||||
using WarpTileIteratorMixed = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
float,
|
||||
32,
|
||||
cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementsPerAccess,
|
||||
8
|
||||
>;
|
||||
|
||||
using WarpTileIteratorNotMixed = cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
float,
|
||||
layout::RowMajor
|
||||
>;
|
||||
|
||||
using WarpTileIterator = typename platform::conditional<
|
||||
(ThreadblockShape::kN == 256),
|
||||
WarpTileIteratorNotMixed,
|
||||
WarpTileIteratorMixed>::type;
|
||||
|
||||
using SharedLoadIteratorMixed = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
|
||||
ThreadMap,
|
||||
float,
|
||||
32,
|
||||
cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementsPerAccess,
|
||||
8
|
||||
>;
|
||||
|
||||
using SharedLoadIteratorNotMixed = cutlass::epilogue::threadblock::SharedLoadIterator<
|
||||
ThreadMap,
|
||||
float
|
||||
>;
|
||||
|
||||
using SharedLoadIterator = typename platform::conditional<
|
||||
(ThreadblockShape::kN == 256),
|
||||
SharedLoadIteratorNotMixed,
|
||||
SharedLoadIteratorMixed>::type;
|
||||
|
||||
static int const kFragmentsPerIteration = 1;
|
||||
};
|
||||
|
||||
/// Partial specialization for float_e5m2_t <= float x 16/8 epilogues avoids shared memory bank conflicts.
|
||||
/// Threadblock::kN = 256 still has bank conflicts.
|
||||
template <
|
||||
int ElementsPerAccess,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename ThreadMap
|
||||
>
|
||||
struct DefaultIteratorsTensorOp<
|
||||
cutlass::float_e5m2_t,
|
||||
float,
|
||||
ElementsPerAccess,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
ThreadMap> {
|
||||
|
||||
using ElementOutput = cutlass::float_e5m2_t;
|
||||
|
||||
static_assert((ElementsPerAccess == 16 || ElementsPerAccess == 8),
|
||||
"ElementsPerAccess needs to be 16 or 8.");
|
||||
|
||||
using WarpTileIteratorMixed = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
float,
|
||||
32,
|
||||
cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementsPerAccess,
|
||||
8
|
||||
>;
|
||||
|
||||
using WarpTileIteratorNotMixed = cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
float,
|
||||
layout::RowMajor
|
||||
>;
|
||||
|
||||
using WarpTileIterator = typename platform::conditional<
|
||||
(ThreadblockShape::kN == 256),
|
||||
WarpTileIteratorNotMixed,
|
||||
WarpTileIteratorMixed>::type;
|
||||
|
||||
using SharedLoadIteratorMixed = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
|
||||
ThreadMap,
|
||||
float,
|
||||
32,
|
||||
cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementsPerAccess,
|
||||
8
|
||||
>;
|
||||
|
||||
using SharedLoadIteratorNotMixed = cutlass::epilogue::threadblock::SharedLoadIterator<
|
||||
ThreadMap,
|
||||
float
|
||||
>;
|
||||
|
||||
using SharedLoadIterator = typename platform::conditional<
|
||||
(ThreadblockShape::kN == 256),
|
||||
SharedLoadIteratorNotMixed,
|
||||
SharedLoadIteratorMixed>::type;
|
||||
|
||||
static int const kFragmentsPerIteration = 1;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
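The bank-conflict note above is realized as a compile-time type switch: platform::conditional (CUTLASS's mirror of std::conditional) selects the plain, non-mixed iterators only when ThreadblockShape::kN == 256. A reduced sketch of the pattern, with the iterator stand-ins hypothetical:

// Illustrative only: compile-time selection between two iterator flavors.
template <int kN>
struct PickWarpTileIterator {
  struct Mixed {};     // stand-in for TileIteratorTensorOpMixed
  struct NotMixed {};  // stand-in for TileIteratorTensorOp
  using type =
      typename cutlass::platform::conditional<(kN == 256), NotMixed, Mixed>::type;
};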
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -1,177 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

  The epilogue rearranges the result of a matrix product through shared memory to match canonical
  tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"

#include "cutlass/gemm/gemm.h"

#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/epilogue_with_broadcast_v2.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for epilogues for TensorOps.
template <
  typename Shape,
  typename WarpMmaTensorOp,
  int PartitionsK,
  typename ElementOutput,
  typename ElementTensor,
  typename ElementVector,
  typename OutputOp,
  int ElementsPerAccess,
  bool ScatterD = false
>
struct DefaultEpilogueWithBroadcastTensorOpV2 {

  /// Use defaults related to the existing epilogue
  using Base = DefaultEpilogueTensorOp<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputOp,
    ElementsPerAccess
  >;

  //
  // Stores the result z = (y = GEMM(A, B, C), broadcast)
  //
  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorV2<
    typename Base::OutputTileThreadMap,
    ElementOutput,
    ScatterD
  >;

  //
  // Additional tensor tile iterator - stores t = Elementwise(z)
  //
  using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorV2<
    typename Base::OutputTileThreadMap,
    ElementTensor
  >;

  /// Define the epilogue
  using Epilogue = EpilogueWithBroadcastV2<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputTileIterator,
    TensorTileIterator,
    ElementVector,
    typename Base::AccumulatorFragmentIterator,
    typename Base::WarpTileIterator,
    typename Base::SharedLoadIterator,
    OutputOp,
    typename Base::Padding,
    Base::kFragmentsPerIteration
  >;
};

////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for epilogues for VoltaTensorOps.
template <
  typename Shape,
  typename WarpMmaTensorOp,
  int PartitionsK,
  typename ElementOutput,
  typename ElementTensor,
  typename ElementVector,
  typename OutputOp,
  int ElementsPerAccess
>
struct DefaultEpilogueWithBroadcastVoltaTensorOpV2 {

  /// Use defaults related to the existing epilogue
  using Base = DefaultEpilogueVoltaTensorOp<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputOp,
    ElementsPerAccess
  >;

  //
  // Stores the result z = (y = GEMM(A, B, C), broadcast)
  //
  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorV2<
    typename Base::OutputTileThreadMap,
    ElementOutput
  >;

  //
  // Additional tensor tile iterator - stores t = Elementwise(z)
  //
  using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorV2<
    typename Base::OutputTileThreadMap,
    ElementTensor
  >;

  /// Define the epilogue
  using Epilogue = EpilogueWithBroadcastV2<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputTileIterator,
    TensorTileIterator,
    ElementVector,
    typename Base::AccumulatorFragmentIterator,
    typename Base::WarpTileIterator,
    typename Base::SharedLoadIterator,
    OutputOp,
    typename Base::Padding
  >;
};

////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
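For reference, the dataflow the deleted V2 defaults wired together: z combines the GEMM
result with a per-column broadcast vector, and t is an elementwise function of z. A scalar
sketch under assumed names (binary_op and elementwise_op stand in for the output functor's
configurable operations; they are illustrative, not the functor's actual API):

// Per output element at (row i, column j):
//   y = alpha * AB[i][j] + beta * C[i][j];  // linear combination of accumulator and source
//   z = binary_op(y, broadcast[j]);         // combined with the broadcast vector, stored via OutputTileIterator
//   t = elementwise_op(z);                  // stored via TensorTileIterator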
@ -34,6 +34,7 @@
  The epilogue rearranges the result of a matrix product through shared memory to match canonical
  tensor layouts in global memory. Epilogues support conversion and reduction operations.

  The shared memory resource is time-sliced across warps.
*/

#pragma once
@ -59,8 +60,9 @@
#include "cutlass/transform/threadblock/regular_tile_iterator.h"

#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/epilogue_base_streamk.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/numeric_types.h"
#include "cutlass/util/index_sequence.h"

////////////////////////////////////////////////////////////////////////////////

@ -68,6 +70,7 @@ namespace cutlass {
namespace epilogue {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Epilogue operator
@ -85,27 +88,39 @@ template <
  int IterationsUnroll =    ///< Used to reduce binary size when epilogue op is large
    (!IsEpilogueFunctorHeavy<OutputOp_>::value)
>
class Epilogue :
class Epilogue :
  public EpilogueBase<
    Shape_,
    typename WarpMmaOperator_::Shape,
    PartitionsK,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
    Shape_,
    typename WarpMmaOperator_::Shape,
    PartitionsK,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
    Padding_,
    FragmentsPerPartition> {
    FragmentsPerPartition>,
  public EpilogueBaseStreamK<
    Shape_,
    PartitionsK,
    WarpMmaOperator_,
    AccumulatorFragmentIterator_>
{

public:

  using Base = EpilogueBase<
    Shape_,
    typename WarpMmaOperator_::Shape,
    PartitionsK,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
    Shape_,
    typename WarpMmaOperator_::Shape,
    PartitionsK,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
    Padding_,
    FragmentsPerPartition>;

  using BaseStreamK = EpilogueBaseStreamK<
    Shape_,
    PartitionsK,
    WarpMmaOperator_,
    AccumulatorFragmentIterator_>;

  using Shape = Shape_;
  using WarpMmaOperator = WarpMmaOperator_;
  static int const kPartitionsK = PartitionsK;
@ -115,15 +130,23 @@ public:
  using SharedLoadIterator = SharedLoadIterator_;
  using OutputOp = OutputOp_;
  using Padding = Padding_;

  using Layout = layout::RowMajor;
  using LongIndex = typename Layout::LongIndex;

  /// The complete warp-level accumulator tile
  /// Number of warps per block
  using WarpCount = typename Base::WarpCount;

  /// Number of threads per block
  static int const kBlockThreads = 32 * WarpCount::kCount;

  /// Per-thread accumulator tile type
  using AccumulatorTile = typename Base::AccumulatorTile;

  /// Accumulator element
  using ElementAccumulator = typename WarpTileIterator::Element;
  /// Numerical accumulation element type
  using ElementAccumulator = typename WarpMmaOperator::ElementC;

  /// Fragment type used by the accumulator tile's fragment iterator
  using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment;

  /// Output element
  using ElementOutput = typename OutputTileIterator::Element;
@ -140,21 +163,20 @@ public:
  /// Const tensor reference to source tensor
  using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;

  /// Array type used to output
  /// Vector type used by the global output iterator
  using OutputAccessType = Array<
    typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;

  /// Array type used by output functor
  using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;

  /// Number of warps
  using WarpCount = typename Base::WarpCount;
  /// Vector type used by the shared output iterator
  using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;

  static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;

  static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;

public:

  static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
    "Mismatch between shared load iterator and output tile iterator.");

@ -163,144 +185,177 @@ public:
  static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
    "Divisibility");

  static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1.");

private:

  /// Loads fragment from shared memory aligned with output tensor
  SharedLoadIterator shared_load_iterator_;

  /// Thread index in the threadblock
  int thread_idx;

  /// Warp index in the threadblock
  int warp_idx;

public:

  /// Constructor
  CUTLASS_DEVICE
  Epilogue(
    typename Base::SharedStorage &shared_storage, ///< Shared storage object
    int thread_idx,                               ///< ID of a thread within the threadblock
    int warp_idx,                                 ///< ID of warp within threadblock
    int lane_idx                                  ///< Id of thread within warp
  ):
    Base(shared_storage, thread_idx, warp_idx, lane_idx),
    shared_load_iterator_(shared_storage.reference(), thread_idx)
    typename Base::SharedStorage &shared_storage, ///< Shared storage object
    int thread_idx,                               ///< ID of a thread within the threadblock
    int warp_idx,                                 ///< ID of warp within threadblock
    int lane_idx)                                 ///< Id of thread within warp
  :
    Base(shared_storage, thread_idx, warp_idx, lane_idx),
    BaseStreamK(thread_idx),
    shared_load_iterator_(shared_storage.reference(), thread_idx),
    thread_idx(thread_idx),
    warp_idx(warp_idx)
  {}
  /// Aggregates the accumulator sets shared by peer blocks in the global workspace,
  /// performing epilogue computations, writing to output
  CUTLASS_DEVICE
  void reduce(
    int peer_idx_begin,
    int peer_idx_end,
    int reduce_fragment_idx,
    ElementAccumulator *element_workspace,
    OutputOp const &output_op,               ///< Output operator
    OutputTileIterator destination_iterator, ///< Tile iterator for destination
    OutputTileIterator source_iterator)      ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
  {
    // Reduce peer accumulator fragments into one fragment
    AccumulatorFragment accum_fragment;
    BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace);

    // Store fragment to shared memory
    this->warp_tile_iterator_.store(accum_fragment);

    __syncthreads();

    // Initialize/load source-fragment data
    typename OutputTileIterator::Fragment source_fragment;
    source_fragment.clear();

    if (output_op.is_source_needed())
    {
      source_iterator += reduce_fragment_idx;
      source_iterator.load(source_fragment);
    }

    // Load fragment from shared memory
    typename SharedLoadIterator::Fragment aligned_accum_fragment;
    shared_load_iterator_.load(aligned_accum_fragment);

    // Add fragments shared by other k partitions
    if (kPartitionsK > 1)
    {
      plus<typename SharedLoadIterator::Fragment> add_fragments;

      CUTLASS_PRAGMA_UNROLL
      for (int i = 1; i < kPartitionsK; ++i) {
        typename SharedLoadIterator::Fragment aligned_addend_fragment;
        shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
        shared_load_iterator_.load(aligned_addend_fragment);
        aligned_accum_fragment = add_fragments(aligned_accum_fragment, aligned_addend_fragment);
      }
    }

    // Compute the output result
    typename OutputTileIterator::Fragment output_fragment;

    // Apply the output operator
    apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment);

    // Store the final result
    destination_iterator += reduce_fragment_idx;
    destination_iterator.store(output_fragment);
  }
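  // A rough sketch of how a StreamK kernel is expected to drive the methods above (names
  // and control flow are illustrative, not taken from this commit): peer blocks that
  // computed partials for a tile share their accumulators through the global workspace,
  // and the owning block reduces all peer contributions before writing output.
  //
  //   if (!owns_tile) {
  //     epilogue.share(peer_idx, element_workspace, accumulators, started_tile);
  //   } else {
  //     epilogue.reduce(peer_idx_begin, peer_idx_end, frag_idx, element_workspace,
  //                     output_op, destination_iterator, source_iterator);
  //   }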
  /// Streams the result to global memory
  CUTLASS_DEVICE
  void operator()(
    OutputOp const &output_op,               ///< Output operator
    OutputTileIterator destination_iterator, ///< Tile iterator for destination
    AccumulatorTile const &accumulators,     ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator) {    ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)

    if (!output_op.is_source_needed()) {
      compute_source_not_needed_(output_op, destination_iterator, accumulators);
    OutputOp const &output_op,               ///< Output operator
    OutputTileIterator destination_iterator, ///< Tile iterator for destination
    AccumulatorTile const &accumulators,     ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator )     ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
  {
    if (!output_op.is_source_needed())
    {
      source_iterator.clear_mask();
      __syncthreads(); // Dummy (CUDA 11.0)
    }
    else {
      compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator);
    }
  }

private:
    // Source-fragment data (zero-initialized for scenarios where the
    // output operator allows us to skip loading it from global input)
    typename OutputTileIterator::Fragment source_fragment;
    source_fragment.clear();

  template <class Seq>
  struct acc2smem_source_not_needed;
    // Iterator over warp-level accumulator fragment
    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

  template <size_t... Seq>
  struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
    template <int Advance>
    CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
                                      WarpTileIterator &warp_tile_iterator) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Advance; i++) {
        ++accum_fragment_iterator;
      }
    //
    // Iterate over accumulator tile
    //

    #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1)
    for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration)
    {

      //
      // Convert and store fragment
      //

      __syncthreads();

      CUTLASS_PRAGMA_UNROLL
      for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
      for (int p = 0; p < Base::kFragmentsPerIteration; ++p)
      {
        typename AccumulatorFragmentIterator::Fragment accum_fragment;

        accum_fragment_iterator.load(accum_fragment);
        ++accum_fragment_iterator;

        warp_tile_iterator.store(accum_fragment);
        this->warp_tile_iterator_.store(accum_fragment);

        if (p < Base::kFragmentsPerIteration - 1) {
          warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
          this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset);
        }
      }

      if (Base::kFragmentsPerIteration > 1) {
        warp_tile_iterator.add_pointer_offset(kSmemPointerOffset *
                                              (1 - Base::kFragmentsPerIteration));
        this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
      }
    }

    CUTLASS_DEVICE
    static void push(size_t pos,
                     AccumulatorFragmentIterator const &iterator_begin,
                     WarpTileIterator &warp_tile_iterator) {
      int dummy[] = {
          (pos == (Seq * Base::kFragmentsPerIteration)) &&
          (helper<Seq * Base::kFragmentsPerIteration>(iterator_begin, warp_tile_iterator), 0)...};

      CUTLASS_UNUSED(dummy[0]);
    }
  };

  static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1.");

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void compute_source_not_needed_(
    OutputOp const &output_op,               ///< Output operator
    OutputTileIterator destination_iterator, ///< Tile iterator for destination
    AccumulatorTile const &accumulators      ///< Complete warp-level accumulator tile
  ) {

    //
    // Iterator over warp-level accumulator fragment
    //

    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

    //
    // Iterate over accumulator tile
    //

    #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1)
    for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) {

      //
      // Convert and store fragment
      //

      __syncthreads();

      acc2smem_source_not_needed<
          cutlass::make_index_sequence<OutputTileIterator::kIterations /
                                       Base::kFragmentsPerIteration>>::push(iter,
                                                                            accum_fragment_iterator,
                                                                            this->warp_tile_iterator_);

      __syncthreads();

      //
      // Load fragments from shared memory
      //

      CUTLASS_PRAGMA_UNROLL
      for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
      __syncthreads();

      CUTLASS_PRAGMA_UNROLL
      for (int p = 0; p < Base::kFragmentsPerIteration; ++p)
      {
        // Load addend source fragment from global memory
        source_iterator.load(source_fragment);
        ++source_iterator;

        typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];

        shared_load_iterator_.load(aligned_accum_fragment[0]);

        if (p < Base::kFragmentsPerIteration - 1) {
        if (p < Base::kFragmentsPerIteration - 1)
        {
          shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
        }
        else if (kPartitionsK > 1) {

        else if (kPartitionsK > 1)
        {
          plus<typename SharedLoadIterator::Fragment> add_fragments;

          CUTLASS_PRAGMA_UNROLL
@ -318,9 +373,7 @@ private:
          //

          typename OutputTileIterator::Fragment output_fragment;

          apply_output_operator_source_not_needed_(output_fragment, output_op, aligned_accum_fragment[0]);

          apply_output_operator(output_fragment, output_op, aligned_accum_fragment[0], source_fragment);

          //
          // Store the final result
@ -336,170 +389,37 @@ private:
      }
    }
  template<class Seq>
  struct acc2smem_source_needed;

  template <size_t... Seq>
  struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
    template<int Advance>
    CUTLASS_DEVICE
    static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
                       WarpTileIterator &warp_tile_iterator) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Advance; i++) {
        ++accum_fragment_iterator;
      }

      typename AccumulatorFragmentIterator::Fragment accum_fragment;
      accum_fragment_iterator.load(accum_fragment);
      warp_tile_iterator.store(accum_fragment);
    }

    CUTLASS_DEVICE
    static void push(size_t pos,
                     AccumulatorFragmentIterator const &iterator_begin,
                     WarpTileIterator &warp_tile_iterator) {
      int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
    }
  };

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void compute_source_needed_(
    OutputOp const &output_op,               ///< Output operator
    OutputTileIterator destination_iterator, ///< Tile iterator for destination
    AccumulatorTile const &accumulators,     ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator       ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
  ) {

    typename OutputTileIterator::Fragment source_fragment;

    source_fragment.clear();

    //
    // Iterator over warp-level accumulator fragment
    //

    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

    //
    // Iterate over accumulator tile
    //

    #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
    for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {

      //
      // Load the source
      //

      source_iterator.load(source_fragment);
      ++source_iterator;

      //
      // Convert and store fragment
      //

      __syncthreads();

      acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
          iter, accum_fragment_iterator, this->warp_tile_iterator_);

      __syncthreads();

      //
      // Load fragments from shared memory
      //

      typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];

      shared_load_iterator_.load(aligned_accum_fragment[0]);

      // If the number of k-slices is > 1 - perform a reduction amongst the k-slices
      if (kPartitionsK > 1) {

        plus<typename SharedLoadIterator::Fragment> add_fragments;

        CUTLASS_PRAGMA_UNROLL
        for (int i = 1; i < kPartitionsK; ++i) {
          shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
          shared_load_iterator_.load(aligned_accum_fragment[i]);
          aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
        }

        shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
      }

      //
      // Compute the output result
      //

      typename OutputTileIterator::Fragment output_fragment;

      apply_output_operator_(output_fragment, output_op, aligned_accum_fragment[0], source_fragment);

      //
      // Store the final result
      //

      destination_iterator.store(output_fragment);
      ++destination_iterator;

    }
  }
private:

  /// Helper to invoke the output functor over each vector of output
  CUTLASS_DEVICE
  void apply_output_operator_(
  void apply_output_operator(
    typename OutputTileIterator::Fragment &output_fragment,
    OutputOp const &output_op,    ///< Output operator
    typename SharedLoadIterator::Fragment const &aligned_accum_fragment,
    typename OutputTileIterator::Fragment const &source_fragment) {

    OutputAccessType *output_frag_ptr =
    typename OutputTileIterator::Fragment const &source_fragment)
  {

    OutputAccessType *output_frag_ptr =
      reinterpret_cast<OutputAccessType *>(&output_fragment);

    AccumulatorAccessType const *compute_frag_ptr =
    AccumulatorAccessType const *compute_frag_ptr =
      reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);

    OutputAccessType const *source_frag_ptr =
    OutputAccessType const *source_frag_ptr =
      reinterpret_cast<OutputAccessType const *>(&source_fragment);

    int const kOutputOpIterations =
    int const kOutputOpIterations =
      OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {

    for (int i = 0; i < kOutputOpIterations; ++i)
    {
      // Call the output operator
      output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
    }
  }

  /// Helper to invoke the output functor over each vector of output
  CUTLASS_DEVICE
  void apply_output_operator_source_not_needed_(
    typename OutputTileIterator::Fragment &output_fragment,
    OutputOp const &output_op,    ///< Output operator
    typename SharedLoadIterator::Fragment const &aligned_accum_fragment) {

    OutputAccessType *output_frag_ptr =
      reinterpret_cast<OutputAccessType *>(&output_fragment);

    AccumulatorAccessType const *compute_frag_ptr =
      reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);

    int const kOutputOpIterations =
      OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {

      // Call the output operator
      output_frag_ptr[i] = output_op(compute_frag_ptr[i]);
    }
  }
};

////////////////////////////////////////////////////////////////////////////////
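For orientation, a hedged sketch of how a threadblock GEMM typically drives this epilogue
once the mainloop has filled its accumulator tile (assuming Epilogue is a full
specialization of the class above; the surrounding kernel variables are illustrative):

__shared__ typename Epilogue::SharedStorage epilogue_storage;

Epilogue epilogue(epilogue_storage, thread_idx, warp_idx, lane_idx);

// Classic data-parallel path: apply the output operator and stream the tile to global memory.
epilogue(output_op, destination_iterator, accumulators, source_iterator);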
191
include/cutlass/epilogue/threadblock/epilogue_base_streamk.h
Normal file
@ -0,0 +1,191 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Basic subset of epilogue functionality for supporting StreamK decompositions
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/block_striped.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// StreamK epilogue functionality for cross-block accumulator fragment reduction
template <
  typename Shape,                       ///< Shape of threadblock tile (concept: GemmShape)
  int PartitionsK,
  typename WarpMmaOperator,             ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
  typename AccumulatorFragmentIterator> ///< Fragment iterator selecting accumulators
class EpilogueBaseStreamK
{

protected:

  /// The complete warp-level accumulator tile
  using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;

  /// Number of warps
  using WarpCount = gemm::GemmShape<
    Shape::kM / WarpMmaOperator::Shape::kM,
    Shape::kN / WarpMmaOperator::Shape::kN, PartitionsK>;

  /// Number of threads per block
  static int const kBlockThreads = 32 * WarpCount::kCount;

  /// Numerical accumulation element type
  using ElementAccumulator = typename WarpMmaOperator::ElementC;

  /// Fragment type used by the accumulator tile's fragment iterator
  using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment;

  /// Block-striped transfer utility for sharing AccumulatorFragment
  using BlockStripedT = BlockStriped<kBlockThreads, AccumulatorFragment, ElementAccumulator>;

  /// Number of elements per fragment
  static int const kFragmentElements = sizeof(AccumulatorFragment) / sizeof(ElementAccumulator);

public:

  /// Number of fragments per accumulator tile
  static int const kAccumulatorFragments = AccumulatorFragmentIterator::Policy::kIterations;

  /// Number of workspace accumulation elements shared per output tile
  static int const kPeerAccumulators = WarpMmaOperator::Shape::kMN * WarpCount::kCount;

protected:

  /// ElementAccumulator stride in the shared workspace between different peer blocks (two: each peer block can share accumulators for up to two tiles)
  static const int kPeerStride = kPeerAccumulators * 2;

public:

  /// Thread index in the threadblock
  int thread_idx;

public:

  /// Constructor
  CUTLASS_DEVICE
  EpilogueBaseStreamK(
    int thread_idx)   ///< ID of a thread within the threadblock
  :
    thread_idx(thread_idx)
  {}

  /// Aggregates the accumulator sets shared by peer blocks in the global workspace
  CUTLASS_DEVICE
  void reduce(
    AccumulatorFragment &accum_fragment,    ///< [out] sum of all shared accumulator fragments for these peer partials
    int peer_idx_begin,
    int peer_idx_end,
    int reduce_fragment_idx,
    ElementAccumulator *element_workspace)
  {
    plus<AccumulatorFragment> add_fragments;

    int accum_set_offset =
      (peer_idx_begin * kPeerStride) +
      (reduce_fragment_idx * kBlockThreads * kFragmentElements);

    // Load first peer fragment
    BlockStripedT::load(accum_fragment, element_workspace + accum_set_offset, this->thread_idx);

    accum_set_offset += kPeerStride;       // Move to next peer
    accum_set_offset += kPeerAccumulators; // Move to non-starting accumulator set for peer

    // Reduce additional peer fragments
    #pragma unroll 2
    while (accum_set_offset < peer_idx_end * kPeerStride)
    {
      AccumulatorFragment addend_fragment;
      BlockStripedT::load(addend_fragment, element_workspace + accum_set_offset, this->thread_idx);
      accum_set_offset += kPeerStride;

      accum_fragment = add_fragments(accum_fragment, addend_fragment);
    }
  }

  /// Shares the accumulator set with peers in the global workspace
  CUTLASS_DEVICE
  void share(
    int peer_idx,
    ElementAccumulator *element_workspace,  ///< Output pointer for writing this block's accumulator set to
    AccumulatorTile const &accumulators,    ///< Complete warp-level accumulator tile
    bool started_tile)
  {
    int accum_set_offset = peer_idx * kPeerStride;

    if (!started_tile) {
      // Move to non-starting accumulator set
      accum_set_offset += kPeerAccumulators;
    }

    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

    CUTLASS_PRAGMA_UNROLL
    for (int iter = 0; iter < kAccumulatorFragments; ++iter)
    {
      // Acquire reordered accumulator fragment
      AccumulatorFragment accum_fragment;
      accum_fragment_iterator.load(accum_fragment);
      ++accum_fragment_iterator;

      // Store accumulator fragment
      BlockStripedT::store(element_workspace + accum_set_offset, accum_fragment, this->thread_idx);

      accum_set_offset += (kFragmentElements * kBlockThreads);
    }
  }

};

////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
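A small worked example of the workspace indexing in reduce(), with illustrative numbers
(not taken from a concrete instantiation):

// Suppose 4 warps of shape 32x32, so WarpMmaOperator::Shape::kMN = 1024,
// kPeerAccumulators = 1024 * 4 = 4096, kBlockThreads = 128, kFragmentElements = 8.
// Then kPeerStride = kPeerAccumulators * 2 = 8192 elements, and
//   accum_set_offset = peer_idx_begin * 8192 + reduce_fragment_idx * 128 * 8;
// i.e. each peer owns an 8192-element slot, and fragment f within a slot starts at
// f * kBlockThreads * kFragmentElements = f * 1024.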
335
include/cutlass/epilogue/threadblock/epilogue_depthwise.h
Normal file
@ -0,0 +1,335 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Epilogue for depthwise convolution

  The epilogue rearranges the result of a matrix product through shared memory to match canonical
  tensor layouts in global memory. Epilogues support conversion and reduction operations.

*/

#pragma once
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/conversion_op.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/reduction_op.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/numeric_types.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Epilogue operator
template <typename Shape_,                       ///< Shape of threadblock tile (concept: GemmShape)
          typename ThreadOutputShape_,           /// Size of the matrix to load (concept: TensorNHWC)
          typename ThreadBlockOutputShape_,      /// Size of the matrix to load (concept: TensorNHWC)
          typename WarpMmaOperator_,             ///< Warp-level MMA operator (concept:
                                                 ///< gemm::warp::MmaTensorOp)
          typename OutputTileIterator_,          ///< Tile iterator reading and writing output tensors
          typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
          typename WarpTileIterator_,            ///< Warp-scoped tile iterator writing accumulators to SMEM
          typename SharedLoadIterator_,          ///< Threadblock-scoped tile iterator loading from SMEM
          typename OutputOp_,                    ///< Output operator
          typename Padding_                      ///< Padding added to SMEM allocation to avoid bank conflicts (concept:
                                                 ///< MatrixShape)
          >
class EpilogueDepthwise {
 public:
  using Shape = Shape_;
  using WarpShape = typename WarpMmaOperator_::Shape;
  using ThreadOutputShape = ThreadOutputShape_;
  using ThreadBlockOutputShape = ThreadBlockOutputShape_;
  using WarpMmaOperator = WarpMmaOperator_;
  using OutputTileIterator = OutputTileIterator_;
  using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
  using WarpTileIterator = WarpTileIterator_;
  using SharedLoadIterator = SharedLoadIterator_;
  using OutputOp = OutputOp_;
  using Padding = Padding_;

  using Layout = layout::RowMajor;
  using LongIndex = typename Layout::LongIndex;

  /// The complete warp-level accumulator tile
  using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;

  /// Accumulator element
  using ElementAccumulator = typename WarpTileIterator::Element;

  /// Output element
  using ElementOutput = typename OutputTileIterator::Element;

  /// Output access size
  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

  /// Tensor reference to destination tensor
  using TensorRef = typename OutputTileIterator::TensorRef;

  /// Tensor reference to sync tensor
  using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;

  /// Const tensor reference to source tensor
  using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;

  /// Array type used to output
  using OutputAccessType =
      Array<typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;

  /// Array type used by output functor
  using AccumulatorAccessType =
      Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;

  /// Number of warps
  using WarpCount =
      gemm::GemmShape<Shape::kM / WarpShape::kM, Shape::kN / WarpShape::kN>;

 public:
  static_assert(SharedLoadIterator::Fragment::kElements ==
                    OutputTileIterator::Fragment::kElements,
                "Mismatch between shared load iterator and output tile iterator.");

  static_assert(OutputTileIterator::kElementsPerAccess,
                "OutputTileIterator::kElementsPerAccess must not be zero.");

  static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
                "Divisibility");

  /// Shared storage allocation needed by the epilogue
  struct SharedStorage {
    //
    // Type definitions
    //

    /// Element type of shared memory
    using Element = typename WarpTileIterator::Element;

    /// Tensor reference to shared memory allocation
    using TensorRef = typename WarpTileIterator::TensorRef;

    /// Layout of shared memory allocation
    using Layout = typename WarpTileIterator::Layout;

    /// Logical shape of the shared memory tile written to by all warps.
    using Shape = MatrixShape<ThreadBlockOutputShape::kNHW, ThreadBlockOutputShape::kC>;

    /// Shape of the shared memory allocation for the epilogue
    using StorageShape = MatrixShape<Shape::kRow, Shape::kColumn>;

    //
    // Data members
    //

    AlignedBuffer<Element, StorageShape::kCount> storage;

    //
    // Methods
    //

    /// Returns a pointer to the shared memory buffer
    CUTLASS_DEVICE
    Element *data() { return storage.data(); }

    /// Returns a tensor reference to the shared memory buffer
    CUTLASS_DEVICE
    TensorRef reference() {
      return TensorRef(storage.data(), Layout::packed({StorageShape::kRow, StorageShape::kColumn}));
    }
  };
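  // As a concrete illustration (hypothetical sizes, not a shipped configuration): with
  // ThreadBlockOutputShape = <N=1, H=8, W=8, C=64>, kNHW = 64 and kC = 64, so Shape and
  // StorageShape describe a 64 x 64 tile and the AlignedBuffer holds 4096 elements.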
 private:
  /// Loads fragment from shared memory aligned with output tensor
  SharedLoadIterator shared_load_iterator_;

  /// Stores a warp's fragment of accumulators to SMEM
  WarpTileIterator warp_tile_iterator_;

  LongIndex warp_offset;
  int thread_idx;
  int warp_idx;
  int lane_idx;
  int warp_m, warp_n;  // warp coordinates within a cta
  int tid_m, tid_n;    // thread coordinates within a warp

 public:
  /// Constructor
  CUTLASS_DEVICE
  EpilogueDepthwise(SharedStorage &shared_storage,  ///< Shared storage object
                    int thread_idx_,                ///< ID of a thread within the threadblock
                    int warp_idx_,                  ///< ID of warp within threadblock
                    int lane_idx_                   ///< Id of thread within warp
                    )
      : thread_idx(thread_idx_),
        warp_idx(warp_idx_),
        lane_idx(lane_idx_),
        shared_load_iterator_(shared_storage.reference(), thread_idx_),
        warp_tile_iterator_(shared_storage.reference(), thread_idx_, lane_idx_) {}

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void operator()(OutputOp const &output_op,                ///< Output operator
                  OutputTileIterator destination_iterator,  ///< Tile iterator for destination
                  AccumulatorTile const &accumulators,      ///< Complete warp-level accumulator tile
                  OutputTileIterator source_iterator,       ///< Threadblock tile coordinate in GEMM (in
                                                            ///< units of threadblock tiles)
                  const int smem_base_offset) {             ///< SMEM base offset for epilogue operation
    // Initialize the SMEM base offset for the current output tile.
    warp_tile_iterator_.set_smem_base_address(smem_base_offset);

    shared_load_iterator_.set_smem_base_address(smem_base_offset);

    if (!output_op.is_source_needed()) {
      compute_source_not_needed_(output_op, destination_iterator, accumulators);
    } else {
      compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator);
    }
  }
 private:
  /// Streams the result to global memory
  CUTLASS_DEVICE
  void compute_source_needed_(
      OutputOp const &output_op,                ///< Output operator
      OutputTileIterator destination_iterator,  ///< Tile iterator for destination
      AccumulatorTile const &accumulators,      ///< Complete warp-level accumulator tile
      OutputTileIterator source_iterator) {     ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)

    typename OutputTileIterator::Fragment source_fragment;

    source_fragment.clear();

    source_iterator.load(source_fragment);

    // store to smem
    warp_tile_iterator_.store(accumulators);

    __syncthreads();

    typename SharedLoadIterator::Fragment aligned_accum_fragment;

    // load from smem
    shared_load_iterator_.load(aligned_accum_fragment);

    typename OutputTileIterator::Fragment output_fragment;

    apply_output_operator_(output_fragment, output_op, aligned_accum_fragment, source_fragment);

    // Store to GMEM
    destination_iterator.store(output_fragment);
  }

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void compute_source_not_needed_(
      OutputOp const &output_op,                ///< Output operator
      OutputTileIterator destination_iterator,  ///< Tile iterator for destination
      AccumulatorTile const &accumulators) {    ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)

    // store to smem
    warp_tile_iterator_.store(accumulators);

    __syncthreads();

    typename SharedLoadIterator::Fragment aligned_accum_fragment;

    // load from smem
    shared_load_iterator_.load(aligned_accum_fragment);

    typename OutputTileIterator::Fragment output_fragment;

    apply_output_operator_source_not_needed_(output_fragment, output_op, aligned_accum_fragment);

    // Store to GMEM
    destination_iterator.store(output_fragment);
  }

  /// Helper to invoke the output functor over each vector of output
  CUTLASS_DEVICE
  void apply_output_operator_(
      typename OutputTileIterator::Fragment &output_fragment,
      OutputOp const &output_op,  ///< Output operator
      typename SharedLoadIterator::Fragment const &aligned_accum_fragment,
      typename OutputTileIterator::Fragment const &source_fragment) {

    OutputAccessType *output_frag_ptr =
        reinterpret_cast<OutputAccessType *>(&output_fragment);

    AccumulatorAccessType const *compute_frag_ptr =
        reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);

    OutputAccessType const *source_frag_ptr =
        reinterpret_cast<OutputAccessType const *>(&source_fragment);

    int const kOutputOpIterations =
        OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {
      // Call the output operator
      output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
    }
  }

  /// Helper to invoke the output functor over each vector of output
  CUTLASS_DEVICE
  void apply_output_operator_source_not_needed_(
      typename OutputTileIterator::Fragment &output_fragment,
      OutputOp const &output_op,  ///< Output operator
      typename SharedLoadIterator::Fragment const &aligned_accum_fragment) {
    OutputAccessType *output_frag_ptr = reinterpret_cast<OutputAccessType *>(&output_fragment);

    AccumulatorAccessType const *compute_frag_ptr =
        reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);

    int const kOutputOpIterations =
        OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {
      // Call the output operator
      output_frag_ptr[i] = output_op(compute_frag_ptr[i]);
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
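A hedged usage sketch for this depthwise epilogue: unlike the GEMM epilogue, the caller
passes an explicit SMEM base offset so successive output tiles can reuse the time-sliced
shared memory (arguments mirror the signatures above; the offset value is illustrative):

EpilogueDepthwise epilogue(shared_storage, thread_idx, warp_idx, lane_idx);

int smem_base_offset = 0;  // illustrative: advanced by the caller per output tile
epilogue(output_op, destination_iterator, accumulators, source_iterator, smem_base_offset);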
@ -77,7 +77,6 @@ public:
  using OutputTileIterator = OutputTileIterator_;
  using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
  using WarpTileIterator = WarpTileIterator_;
  using SharedLoadIterator = SharedLoadIterator_;
  using OutputOp = OutputOp_;
  using Padding = MatrixShape<0, 0>;

@ -133,7 +133,8 @@ struct EpilogueWithBroadcastOpBase {
    FragmentZ &frag_Z,
    FragmentT &frag_T,
    FragmentAccumulator const &AB,
    FragmentC const &frag_C,
    FragmentC const &frag_C1,
    FragmentC const &frag_C2,
    FragmentCompute const &V) const {

  }
@ -180,9 +181,42 @@ template <
  typename Padding_,             ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
  int FragmentsPerPartition = 1, ///< Used to coarsen the epilogue granularity
  int IterationsUnroll =         ///< Used to reduce binary size when epilogue op is large
    (!IsEpilogueFunctorHeavy<OutputOp_>::value)
    (!IsEpilogueFunctorHeavy<OutputOp_>::value),
  bool IsSingleSource = OutputOp_::kIsSingleSource
>
class EpilogueWithBroadcast :
class EpilogueWithBroadcast;

template <
  typename Shape_,
  typename WarpMmaOperator_,
  int PartitionsK,
  typename OutputTileIterator_,
  typename TensorTileIterator_,
  typename ElementVector_,
  typename AccumulatorFragmentIterator_,
  typename WarpTileIterator_,
  typename SharedLoadIterator_,
  typename OutputOp_,
  typename Padding_,
  int FragmentsPerPartition,
  int IterationsUnroll
>
class EpilogueWithBroadcast<
  Shape_,
  WarpMmaOperator_,
  PartitionsK,
  OutputTileIterator_,
  TensorTileIterator_,
  ElementVector_,
  AccumulatorFragmentIterator_,
  WarpTileIterator_,
  SharedLoadIterator_,
  OutputOp_,
  Padding_,
  FragmentsPerPartition,
  IterationsUnroll,
  false
> :
  public EpilogueBase<
    Shape_,
    typename WarpMmaOperator_::Shape,
@ -203,6 +237,7 @@ public:
    Padding_,
    FragmentsPerPartition>;

  static bool const kIsSingleSource = false;
  using Shape = Shape_;
  using WarpMmaOperator = WarpMmaOperator_;
  static int const kPartitionsK = PartitionsK;
@ -383,7 +418,687 @@ public:
  CUTLASS_DEVICE
  void operator()(
    OutputOp const &output_op,               ///< Output operator
    ElementVector const * broadcast_ptr,     ///< Broadcast vector
    ElementVector const * broadcast_ptr,     ///< Broadcast vector
    OutputTileIterator destination_iterator, ///< Tile iterator for destination
    AccumulatorTile const &accumulators,     ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator1,     ///< Tile iterator for first source accumulator matrix
    OutputTileIterator source_iterator2,     ///< Tile iterator for second source accumulator matrix
    TensorTileIterator tensor_iterator,      ///< Threadblock tile iterator for additional tensor operand
    MatrixCoord const &problem_size =        ///< Problem size needed to guard against out-of-bounds accesses
        MatrixCoord(Shape::kM, Shape::kN),
    MatrixCoord const &threadblock_offset =  ///< Threadblock's initial offset within the problem size space
        MatrixCoord()) {

    BroadcastFragment broadcast_fragment;

    load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset);

    if (!output_op.is_source_needed()) {
      compute_source_not_needed_(
        output_op,
        broadcast_fragment,
        destination_iterator,
        accumulators,
        tensor_iterator);
    }
    else {
      compute_source_needed_(
        output_op,
        broadcast_fragment,
        destination_iterator,
        accumulators,
        source_iterator1,
        source_iterator2,
        tensor_iterator);
    }
  }
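  // Usage sketch for this dual-source specialization (kIsSingleSource == false): the two
  // source iterators feed frag_C1 and frag_C2 of the broadcast output functor. An
  // illustrative call mirroring the signature above:
  //
  //   epilogue(output_op, broadcast_ptr, destination_iterator, accumulators,
  //            source_iterator1, source_iterator2, tensor_iterator,
  //            problem_size, threadblock_offset);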
private:

  CUTLASS_DEVICE
  void load_broadcast_fragment_(
    BroadcastFragment & broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
    ElementVector const * broadcast_ptr,    ///< Broadcast vector
    MatrixCoord const &problem_size,        ///< Problem size needed to guard against out-of-bounds accesses
    MatrixCoord const &threadblock_offset   ///< Threadblock's initial offset within the problem size space
  ) {

    broadcast_fragment.clear();

    // If no pointer is supplied, set with all zeros and avoid memory accesses
    if (!broadcast_ptr) {
      return;
    }

    int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column();

    int thread_column_idx = threadblock_offset.column() + thread_initial_column;
    broadcast_ptr += thread_initial_column;

    NumericArrayConverter<ElementCompute, ElementVector, BroadcastDetail::kElementsPerAccess> converter;
    using AccessType = AlignedArray<ElementVector, BroadcastDetail::kElementsPerAccess>;
    using ComputeFragmentType = Array<ElementCompute, BroadcastDetail::kElementsPerAccess>;

    ComputeFragmentType *frag_ptr = reinterpret_cast<ComputeFragmentType *>(&broadcast_fragment);

    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) {

      AccessType loaded;

      loaded.clear();

      if (thread_column_idx < problem_size.column()) {
        loaded = *reinterpret_cast<AccessType const *>(broadcast_ptr);
      }

      ComputeFragmentType cvt = converter(loaded);
      frag_ptr[j] = cvt;

      thread_column_idx += ThreadMap::Delta::kColumn;
      broadcast_ptr += ThreadMap::Delta::kColumn;
    }
  }
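  // The loop above is a standard predicated vector-load pattern: the access executes only
  // while thread_column_idx stays inside problem_size.column(); otherwise the cleared
  // AccessType supplies zeros. A minimal scalar analogue (illustrative only):
  //
  //   ElementCompute value = ElementCompute(0);
  //   if (col < extent) { value = ElementCompute(ptr[col]); }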
|
||||
template <class Seq>
|
||||
struct acc2smem_source_not_needed;
|
||||
|
||||
template <size_t... Seq>
|
||||
struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
|
||||
template <int Advance>
|
||||
CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
|
||||
WarpTileIterator &warp_tile_iterator) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < Advance; i++) {
|
||||
++accum_fragment_iterator;
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
|
||||
typename AccumulatorFragmentIterator::Fragment accum_fragment;
|
||||
|
||||
accum_fragment_iterator.load(accum_fragment);
|
||||
++accum_fragment_iterator;
|
||||
|
||||
warp_tile_iterator.store(accum_fragment);
|
||||
if (p < Base::kFragmentsPerIteration - 1) {
|
||||
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
|
||||
}
|
||||
}
|
||||
|
||||
if (Base::kFragmentsPerIteration > 1) {
|
||||
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset *
|
||||
(1 - Base::kFragmentsPerIteration));
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void push(size_t pos,
|
||||
AccumulatorFragmentIterator const &iterator_begin,
|
||||
WarpTileIterator &warp_tile_iterator) {
|
||||
int dummy[] = {
|
||||
(pos == (Seq * Base::kFragmentsPerIteration)) &&
|
||||
(helper<Seq * Base::kFragmentsPerIteration>(iterator_begin, warp_tile_iterator), 0)...};
|
||||
|
||||
CUTLASS_UNUSED(dummy[0]);
|
||||
}
|
||||
};
|
||||

/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_not_needed_(
  OutputOp const &output_op,                    ///< Output operator
  BroadcastFragment const &broadcast_fragment,  ///< Fragment containing the accumulated partial reduction over columns
  OutputTileIterator destination_iterator,      ///< Tile iterator for destination
  AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
  TensorTileIterator tensor_iterator            ///< Threadblock tile iterator for additional tensor operand
) {

  //
  // Iterator over warp-level accumulator fragment
  //

  AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

  //
  // Iterate over accumulator tile
  //

  // CUTLASS_PRAGMA_UNROLL
  #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1)
  for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) {

    //
    // Convert and store fragment
    //

    __syncthreads();

    acc2smem_source_not_needed<
        cutlass::make_index_sequence<OutputTileIterator::kIterations /
                                     Base::kFragmentsPerIteration>>::push(iter,
                                                                          accum_fragment_iterator,
                                                                          this->warp_tile_iterator_);

    __syncthreads();

    //
    // Load fragments from shared memory
    //

    CUTLASS_PRAGMA_UNROLL
    for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {

      typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];

      shared_load_iterator_.load(aligned_accum_fragment[0]);

      if (p < Base::kFragmentsPerIteration - 1) {
        shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
      }
      else if (kPartitionsK > 1) {

        plus<typename SharedLoadIterator::Fragment> add_fragments;

        CUTLASS_PRAGMA_UNROLL
        for (int i = 1; i < kPartitionsK; ++i) {
          shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
          shared_load_iterator_.load(aligned_accum_fragment[i]);
          aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
        }

        shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
      }

      //
      // Apply output operation
      //

      typename OutputTileIterator::Fragment frag_Z;
      typename TensorTileIterator::Fragment frag_T;

      apply_output_operator_source_not_needed_(
          frag_Z,
          frag_T,
          output_op,
          aligned_accum_fragment[0],
          broadcast_fragment);

      //
      // Conditionally store fragments
      //

      if (OutputOp::kStoreZ) {
        destination_iterator.store(frag_Z);
        ++destination_iterator;
      }

      if (OutputOp::kStoreT) {
        tensor_iterator.store(frag_T);
        ++tensor_iterator;
      }
    }

    if (Base::kFragmentsPerIteration > 1) {
      shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
    }
  }
}
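// Each outer iteration above is one round trip through shared memory: the
// first __syncthreads() ensures prior readers are finished with the SMEM tile
// before it is overwritten, and the second publishes the freshly staged
// accumulators to all warps before they are reloaded in the output-friendly
// layout expected by the shared-load iterator.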

template <class Seq>
struct acc2smem_source_needed;

template <size_t... Seq>
struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
  template <int Advance>
  CUTLASS_DEVICE
  static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
                     WarpTileIterator &warp_tile_iterator) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < Advance; i++) {
      ++accum_fragment_iterator;
    }

    typename AccumulatorFragmentIterator::Fragment accum_fragment;
    accum_fragment_iterator.load(accum_fragment);
    warp_tile_iterator.store(accum_fragment);
  }

  CUTLASS_DEVICE
  static void push(size_t pos,
                   AccumulatorFragmentIterator const &iterator_begin,
                   WarpTileIterator &warp_tile_iterator) {
    int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};

    CUTLASS_UNUSED(dummy[0]);
  }
};

/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_needed_(
  OutputOp const &output_op,                    ///< Output operator
  BroadcastFragment const &broadcast_fragment,  ///< Fragment containing the accumulated partial reduction over columns
  OutputTileIterator destination_iterator,      ///< Tile iterator for destination
  AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
  OutputTileIterator source_iterator1,          ///< Tile iterator for first source accumulator matrix
  OutputTileIterator source_iterator2,          ///< Tile iterator for second source accumulator matrix
  TensorTileIterator tensor_iterator            ///< Threadblock tile iterator for additional tensor operand
) {

  typename OutputTileIterator::Fragment source_fragment1;
  source_fragment1.clear();
  typename OutputTileIterator::Fragment source_fragment2;
  source_fragment2.clear();

  //
  // Iterator over warp-level accumulator fragment
  //

  AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

  //
  // Iterate over accumulator tile
  //

  #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
  for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {

    //
    // Load the source
    //

    source_iterator1.load(source_fragment1);
    ++source_iterator1;

    source_iterator2.load(source_fragment2);
    ++source_iterator2;

    //
    // Convert and store fragment
    //

    __syncthreads();

    acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
        iter, accum_fragment_iterator, this->warp_tile_iterator_);

    __syncthreads();

    //
    // Load fragments from shared memory
    //

    typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];

    shared_load_iterator_.load(aligned_accum_fragment[0]);

    // If the number of k-slices is > 1, perform a reduction amongst the k-slices
    if (kPartitionsK > 1) {
      plus<typename SharedLoadIterator::Fragment> add_fragments;
      const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK;

      CUTLASS_PRAGMA_UNROLL
      for (int i = 1; i < kPartitionsK; ++i) {
        shared_load_iterator_.add_tile_offset({tile_row_offset, 0});
        shared_load_iterator_.load(aligned_accum_fragment[i]);
        aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
      }

      shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK - 1) * tile_row_offset, 0});
    }

    //
    // Apply output operation
    //

    typename OutputTileIterator::Fragment frag_Z;
    typename TensorTileIterator::Fragment frag_T;

    apply_output_operator_(
        frag_Z,
        frag_T,
        output_op,
        aligned_accum_fragment[0],
        source_fragment1,
        source_fragment2,
        broadcast_fragment);

    //
    // Conditionally store fragments
    //

    if (OutputOp::kStoreZ) {
      destination_iterator.store(frag_Z);
      ++destination_iterator;
    }

    if (OutputOp::kStoreT) {
      tensor_iterator.store(frag_T);
      ++tensor_iterator;
    }
  }
}
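// Split-K note: when kPartitionsK > 1, each k-partition's partial accumulators
// occupy a distinct row band of the SMEM tile. The reduction above walks the
// bands with add_tile_offset({tile_row_offset, 0}), sums them elementwise, and
// then rewinds the iterator so the next output iteration starts at band 0.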

/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_(
  typename OutputTileIterator::Fragment &frag_Z,
  typename TensorTileIterator::Fragment &frag_T,
  OutputOp const &output_op,
  typename SharedLoadIterator::Fragment const &frag_AB,
  typename OutputTileIterator::Fragment const &frag_C1,
  typename OutputTileIterator::Fragment const &frag_C2,
  BroadcastFragment const &frag_Broadcast) {

  using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
  using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
  using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;

  AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
  AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);

  AccumulatorAccessType const *frag_AB_ptr =
      reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);

  OutputAccessType const *frag_C1_ptr =
      reinterpret_cast<OutputAccessType const *>(&frag_C1);

  OutputAccessType const *frag_C2_ptr =
      reinterpret_cast<OutputAccessType const *>(&frag_C2);

  AccessTypeBroadcast const *frag_Broadcast_ptr =
      reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);

  int const kOutputOpIterations =
      OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;

  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < kOutputOpIterations; ++i) {
    output_op(
        frag_Z_ptr[i],
        frag_T_ptr[i],
        frag_AB_ptr[i],
        frag_C1_ptr[i],
        frag_C2_ptr[i],
        frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
  }
}
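// The fragments are reinterpreted as arrays of kElementsPerAccess-wide vectors
// so the functor runs once per vector rather than once per scalar. The
// broadcast operand is indexed modulo ThreadMap::Iterations::kColumn because
// it varies only with column position and therefore repeats across rows.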

/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_source_not_needed_(
  typename OutputTileIterator::Fragment &frag_Z,
  typename TensorTileIterator::Fragment &frag_T,
  OutputOp const &output_op,
  typename SharedLoadIterator::Fragment const &frag_AB,
  BroadcastFragment const &frag_Broadcast) {

  using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
  using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
  using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;

  AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
  AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);

  AccumulatorAccessType const *frag_AB_ptr =
      reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);

  AccessTypeBroadcast const *frag_Broadcast_ptr =
      reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);

  int const kOutputOpIterations =
      OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;

  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < kOutputOpIterations; ++i) {

    output_op(
        frag_Z_ptr[i],
        frag_T_ptr[i],
        frag_AB_ptr[i],
        frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
  }
}
};


template <
  typename Shape_,
  typename WarpMmaOperator_,
  int PartitionsK,
  typename OutputTileIterator_,
  typename TensorTileIterator_,
  typename ElementVector_,
  typename AccumulatorFragmentIterator_,
  typename WarpTileIterator_,
  typename SharedLoadIterator_,
  typename OutputOp_,
  typename Padding_,
  int FragmentsPerPartition,
  int IterationsUnroll
>
class EpilogueWithBroadcast<
  Shape_,
  WarpMmaOperator_,
  PartitionsK,
  OutputTileIterator_,
  TensorTileIterator_,
  ElementVector_,
  AccumulatorFragmentIterator_,
  WarpTileIterator_,
  SharedLoadIterator_,
  OutputOp_,
  Padding_,
  FragmentsPerPartition,
  IterationsUnroll,
  true
> :
  public EpilogueBase<
    Shape_,
    typename WarpMmaOperator_::Shape,
    PartitionsK,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
    Padding_,
    FragmentsPerPartition> {

public:

  using Base = EpilogueBase<
    Shape_,
    typename WarpMmaOperator_::Shape,
    PartitionsK,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
    Padding_,
    FragmentsPerPartition>;

  static bool const kIsSingleSource = true;
  using Shape = Shape_;
  using WarpMmaOperator = WarpMmaOperator_;
  static int const kPartitionsK = PartitionsK;
  using OutputTileIterator = OutputTileIterator_;
  using TensorTileIterator = TensorTileIterator_;
  using ElementVector = ElementVector_;
  using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
  using WarpTileIterator = WarpTileIterator_;
  using SharedLoadIterator = SharedLoadIterator_;
  using OutputOp = OutputOp_;
  using Padding = Padding_;

  using Layout = layout::RowMajor;
  using LongIndex = typename Layout::LongIndex;

  /// The complete warp-level accumulator tile
  using AccumulatorTile = typename Base::AccumulatorTile;

  /// Accumulator element
  using ElementAccumulator = typename WarpTileIterator::Element;

  /// Compute data type produced by the output op
  using ElementCompute = typename OutputOp::ElementCompute;

  /// Compute fragment
  using FragmentCompute = Array<ElementCompute, OutputTileIterator::Fragment::kElements>;

  /// Thread map used by output tile iterators
  using ThreadMap = typename OutputTileIterator::ThreadMap;

  /// Fragment object used to store the broadcast values
  using BroadcastFragment = Array<
    ElementCompute,
    ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>;

  /// Output element
  using ElementOutput = typename OutputTileIterator::Element;

  /// Data type of additional tensor
  using ElementTensor = typename TensorTileIterator::Element;

  /// Output access size
  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

  /// Tensor reference to destination tensor
  using TensorRef = typename OutputTileIterator::TensorRef;

  /// Tensor reference to sync tensor
  using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;

  /// Const tensor reference to source tensor
  using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;

  /// Array type used to output
  using OutputAccessType = Array<
    typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;

  /// Array type used by output functor
  using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;

  /// Array type used by output functor
  using ComputeAccessType = Array<ElementCompute, OutputTileIterator::kElementsPerAccess>;

  /// Tensor access type
  using TensorAccessType = Array<ElementTensor, OutputTileIterator::kElementsPerAccess>;

  /// Number of warps
  using WarpCount = typename Base::WarpCount;

  /// Shared memory allocation from epilogue base class
  using BaseSharedStorage = typename Base::SharedStorage;

  static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
  static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;
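  // kSmemPointerOffset partitions the SMEM allocation into kSmemTiles equal
  // slices: one slice per staged fragment when double-buffering across
  // iterations (kFragmentsPerIteration > 1), otherwise one per k-partition.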

  /// Used for the broadcast
  struct BroadcastDetail {

    /// Number of threads per warp
    static int const kWarpSize = 32;

    static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;

    /// Number of distinct scalar column indices handled by each thread
    static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;

    /// Number of distinct scalar row indices handled by each thread
    static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;

    /// Number of threads per threadblock
    static int const kThreadCount = kWarpSize * WarpCount::kCount;

    /// Number of distinct threads per row of output tile
    static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread);

    /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
    static int const kThreadRows = kThreadCount / kThreadsPerRow;

    /// Number of accesses each thread must make to cover one full row of the output tile
    static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);

    /// Shape of the shared memory allocation for the epilogue
    using StorageShape = MatrixShape<
      kThreadRows,
      Shape::kN
    >;

    /// Debug printing
    CUTLASS_DEVICE
    static void print() {
#if 0
      printf("BroadcastDetail {\n");
      printf(
        "  kColumnsPerThread: %d\n  kRowsPerThread: %d\n  kThreadCount: %d\n  kThreadsPerRow: %d\n"
        "  kThreadRows: %d\n  kThreadAccessesPerRow: %d\n  StorageShape: %d x %d (count: %d)\n",
        kColumnsPerThread,
        kRowsPerThread,
        kThreadCount,
        kThreadsPerRow,
        kThreadRows,
        kThreadAccessesPerRow,
        StorageShape::kRow,
        StorageShape::kColumn,
        StorageShape::kCount
      );
      printf("};\n");
#endif
    }
  };

  /// Shared storage structure (shadows base) with additional SMEM buffer for reduction
  struct SharedStorage {
    union {
      BaseSharedStorage base;
    };

    CUTLASS_HOST_DEVICE
    SharedStorage() { }
  };
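  // Note: the union currently has a single member; writing it as a union keeps
  // the layout ready for an additional reduction buffer to alias the base
  // allocation (per the comment above) without growing shared memory.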

public:

  static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
    "Mismatch between shared load iterator and output tile iterator.");

  static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");

  static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
    "Divisibility");

private:

  /// Loads fragment from shared memory aligned with output tensor
  SharedLoadIterator shared_load_iterator_;

  /// Thread index within the threadblock
  int thread_idx_;

public:

  /// Constructor
  CUTLASS_DEVICE
  EpilogueWithBroadcast(
    SharedStorage &shared_storage,  ///< Shared storage object
    int thread_idx,                 ///< ID of a thread within the threadblock
    int warp_idx,                   ///< ID of warp within threadblock
    int lane_idx                    ///< Id of thread within warp
  ):
    Base(shared_storage.base, thread_idx, warp_idx, lane_idx),
    shared_load_iterator_(shared_storage.base.reference(), thread_idx),
    thread_idx_(thread_idx)
  {
  }

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void operator()(
    OutputOp const &output_op,                ///< Output operator
    ElementVector const * broadcast_ptr,      ///< Broadcast vector
    OutputTileIterator destination_iterator,  ///< Tile iterator for destination
    AccumulatorTile const &accumulators,      ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator,       ///< Tile iterator for source accumulator matrix
@@ -646,7 +1361,7 @@ private:
    BroadcastFragment const &broadcast_fragment,  ///< Fragment containing the accumulated partial reduction over columns
    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator,           ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
    OutputTileIterator source_iterator,           ///< Tile iterator for source accumulator matrix
    TensorTileIterator tensor_iterator            ///< Threadblock tile iterator for additional tensor operand
  ) {

@@ -759,7 +1474,7 @@ private:
    AccumulatorAccessType const *frag_AB_ptr =
      reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);

    OutputAccessType const *frag_C_ptr =
      reinterpret_cast<OutputAccessType const *>(&frag_C);

    AccessTypeBroadcast const *frag_Broadcast_ptr =
@@ -770,13 +1485,12 @@ private:

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {

      output_op(
        frag_Z_ptr[i],
        frag_T_ptr[i],
        frag_AB_ptr[i],
        frag_C_ptr[i],
        frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
    }
  }


@@ -1,847 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
  The epilogue rearranges the result of a matrix product through shared memory to match canonical
  tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/

#pragma once

#include <utility>
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/functional.h"
#include "cutlass/fast_math.h"
#include "cutlass/layout/vector.h"
#include "cutlass/layout/tensor.h"

#include "cutlass/gemm/gemm.h"

#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"

#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_v2.h"

#include "cutlass/util/index_sequence.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// This base class is meant to define the concept required of the
/// EpilogueWithBroadcast::OutputOp
template <
  typename ElementC_,
  typename ElementAccumulator_,
  typename ElementCompute_,
  typename ElementZ_,
  typename ElementT_,
  int ElementsPerAccess,
  bool StoreZ = true,
  bool StoreT = true
>
struct EpilogueWithBroadcastOpBaseV2 {

  using ElementOutput = ElementC_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  using ElementZ = ElementZ_;
  using ElementT = ElementT_;
  static int const kElementsPerAccess = ElementsPerAccess;

  using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
  using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
  using FragmentC = Array<ElementOutput, kElementsPerAccess>;
  using FragmentZ = Array<ElementZ, kElementsPerAccess>;
  using FragmentT = Array<ElementT, kElementsPerAccess>;

  /// If true, the 'Z' tensor is stored
  static bool const kStoreZ = StoreZ;

  /// If true, the 'T' tensor is stored
  static bool const kStoreT = StoreT;

  /// Parameters structure - required
  struct Params { };

  //
  // Methods
  //

  /// Constructor from Params
  EpilogueWithBroadcastOpBaseV2(Params const &params_) { }

  /// Determine if the source is needed. May return false if the output
  /// operator ignores the source tensor.
  bool is_source_needed() const {
    return true;
  }

  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) { }

  /// Applies the operation when is_source_needed() is true
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentZ &frag_Z,
    FragmentT &frag_T,
    FragmentAccumulator const &AB,
    FragmentC const &frag_C1,
    FragmentC const &frag_C2,
    FragmentCompute const &V) const {

  }

  /// Applies the operation when is_source_needed() is false
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentZ &frag_Z,
    FragmentT &frag_T,
    FragmentAccumulator const &AB,
    FragmentCompute const &V) const {

  }
};
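// A concrete operator modeling this concept typically computes something like
//   Z = activation(alpha * AB + beta * C + V),   T = alpha * AB + V
// elementwise over each kElementsPerAccess-wide fragment. This is illustrative
// only; see cutlass/epilogue/thread for concrete output operators.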

////////////////////////////////////////////////////////////////////////////////

/// Epilogue operator with bias vector broadcast over columns.
///
/// Computes the following:
///
///
///   Z, T = OutputOp(AB, C, Broadcast)
///
///   if (ElementwiseOp::kStoreZ) {
///     store(converted_u);
///   }
///
///   if (ElementwiseOp::kStoreT) {
///     store(v);
///   }
///
template <
  typename Shape_,                        ///< Shape of threadblock tile (concept: GemmShape)
  typename WarpMmaOperator_,              ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
  int PartitionsK,                        ///< Number of partitions of the K dimension
  typename OutputTileIterator_,           ///< Tile iterator reading and writing output tensors (z)
  typename TensorTileIterator_,           ///< Additional tile iterator for tensor-valued operands (t)
  typename ElementVector_,                ///< Pointer to broadcast vector
  typename AccumulatorFragmentIterator_,  ///< Fragment iterator selecting accumulators
  typename WarpTileIterator_,             ///< Warp-scoped tile iterator writing accumulators to SMEM
  typename SharedLoadIterator_,           ///< Threadblock-scoped tile iterator loading from SMEM
  typename OutputOp_,                     ///< Output operator - concept is EpilogueWithBroadcastOp
  typename Padding_,                      ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
  int FragmentsPerPartition = 1,          ///< Used to coarsen the epilogue granularity
  int IterationsUnroll =                  ///< Used to reduce binary size when epilogue op is large
    (!IsEpilogueFunctorHeavy<OutputOp_>::value)
>
class EpilogueWithBroadcastV2 :
  public EpilogueBase<
    Shape_,
    typename WarpMmaOperator_::Shape,
    PartitionsK,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
    Padding_,
    FragmentsPerPartition> {

public:

  using Base = EpilogueBase<
    Shape_,
    typename WarpMmaOperator_::Shape,
    PartitionsK,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
    Padding_,
    FragmentsPerPartition>;

  using Shape = Shape_;
  using WarpMmaOperator = WarpMmaOperator_;
  static int const kPartitionsK = PartitionsK;
  using OutputTileIterator = OutputTileIterator_;
  using TensorTileIterator = TensorTileIterator_;
  using ElementVector = ElementVector_;
  using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
  using WarpTileIterator = WarpTileIterator_;
  using SharedLoadIterator = SharedLoadIterator_;
  using OutputOp = OutputOp_;
  using Padding = Padding_;

  using Layout = layout::RowMajor;
  using LongIndex = typename Layout::LongIndex;

  /// The complete warp-level accumulator tile
  using AccumulatorTile = typename Base::AccumulatorTile;

  /// Accumulator element
  using ElementAccumulator = typename WarpTileIterator::Element;

  /// Compute data type produced by the output op
  using ElementCompute = typename OutputOp::ElementCompute;

  /// Compute fragment
  using FragmentCompute = Array<ElementCompute, OutputTileIterator::Fragment::kElements>;

  /// Thread map used by output tile iterators
  using ThreadMap = typename OutputTileIterator::ThreadMap;

  /// Fragment object used to store the broadcast values
  using BroadcastFragment = Array<
    ElementCompute,
    ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>;

  /// Output element
  using ElementOutput = typename OutputTileIterator::Element;

  /// Data type of additional tensor
  using ElementTensor = typename TensorTileIterator::Element;

  /// Output access size
  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

  /// Tensor reference to destination tensor
  using TensorRef = typename OutputTileIterator::TensorRef;

  /// Tensor reference to sync tensor
  using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;

  /// Const tensor reference to source tensor
  using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;

  /// Array type used to output
  using OutputAccessType = Array<
    typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;

  /// Array type used by output functor
  using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;

  /// Array type used by output functor
  using ComputeAccessType = Array<ElementCompute, OutputTileIterator::kElementsPerAccess>;

  /// Tensor access type
  using TensorAccessType = Array<ElementTensor, OutputTileIterator::kElementsPerAccess>;

  /// Number of warps
  using WarpCount = typename Base::WarpCount;

  /// Shared memory allocation from epilogue base class
  using BaseSharedStorage = typename Base::SharedStorage;

  static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
  static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;

  /// Used for the broadcast
  struct BroadcastDetail {

    /// Number of threads per warp
    static int const kWarpSize = 32;

    static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;

    /// Number of distinct scalar column indices handled by each thread
    static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess;

    /// Number of distinct scalar row indices handled by each thread
    static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn;

    /// Number of threads per threadblock
    static int const kThreadCount = kWarpSize * WarpCount::kCount;

    /// Number of distinct threads per row of output tile
    static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread);

    /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock.
    static int const kThreadRows = kThreadCount / kThreadsPerRow;

    /// Number of accesses each thread must make to cover one full row of the output tile
    static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount);

    /// Shape of the shared memory allocation for the epilogue
    using StorageShape = MatrixShape<
      kThreadRows,
      Shape::kN
    >;

    /// Debug printing
    CUTLASS_DEVICE
    static void print() {
#if 0
      printf("BroadcastDetail {\n");
      printf(
        "  kColumnsPerThread: %d\n  kRowsPerThread: %d\n  kThreadCount: %d\n  kThreadsPerRow: %d\n"
        "  kThreadRows: %d\n  kThreadAccessesPerRow: %d\n  StorageShape: %d x %d (count: %d)\n",
        kColumnsPerThread,
        kRowsPerThread,
        kThreadCount,
        kThreadsPerRow,
        kThreadRows,
        kThreadAccessesPerRow,
        StorageShape::kRow,
        StorageShape::kColumn,
        StorageShape::kCount
      );
      printf("};\n");
#endif
    }
  };

  /// Shared storage structure (shadows base) with additional SMEM buffer for reduction
  struct SharedStorage {
    union {
      BaseSharedStorage base;
    };

    CUTLASS_HOST_DEVICE
    SharedStorage() { }
  };

public:

  static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
    "Mismatch between shared load iterator and output tile iterator.");

  static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");

  static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
    "Divisibility");

private:

  /// Loads fragment from shared memory aligned with output tensor
  SharedLoadIterator shared_load_iterator_;

  /// Thread index within the threadblock
  int thread_idx_;

public:

  /// Constructor
  CUTLASS_DEVICE
  EpilogueWithBroadcastV2(
    SharedStorage &shared_storage,  ///< Shared storage object
    int thread_idx,                 ///< ID of a thread within the threadblock
    int warp_idx,                   ///< ID of warp within threadblock
    int lane_idx                    ///< Id of thread within warp
  ):
    Base(shared_storage.base, thread_idx, warp_idx, lane_idx),
    shared_load_iterator_(shared_storage.base.reference(), thread_idx),
    thread_idx_(thread_idx)
  {
  }

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void operator()(
    OutputOp const &output_op,                ///< Output operator
    ElementVector const * broadcast_ptr,      ///< Broadcast vector
    OutputTileIterator destination_iterator,  ///< Tile iterator for destination
    AccumulatorTile const &accumulators,      ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator1,      ///< Tile iterator for source accumulator matrix
    OutputTileIterator source_iterator2,      ///< Tile iterator for source accumulator matrix
    TensorTileIterator tensor_iterator,       ///< Threadblock tile iterator for additional tensor operand
    MatrixCoord const &problem_size =         ///< Problem size needed to guard against out-of-bounds accesses
        MatrixCoord(Shape::kM, Shape::kN),
    MatrixCoord const &threadblock_offset =   ///< Threadblock's initial offset within the problem size space
        MatrixCoord()) {

    BroadcastFragment broadcast_fragment;

    load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset);

    if (!output_op.is_source_needed()) {
      compute_source_not_needed_(
        output_op,
        broadcast_fragment,
        destination_iterator,
        accumulators,
        tensor_iterator);
    }
    else {
      compute_source_needed_(
        output_op,
        broadcast_fragment,
        destination_iterator,
        accumulators,
        source_iterator1,
        source_iterator2,
        tensor_iterator);
    }
  }

private:

  CUTLASS_DEVICE
  void load_broadcast_fragment_(
    BroadcastFragment & broadcast_fragment,  ///< Fragment containing the accumulated partial reduction over columns
    ElementVector const * broadcast_ptr,     ///< Broadcast vector
    MatrixCoord const &problem_size,         ///< Problem size needed to guard against out-of-bounds accesses
    MatrixCoord const &threadblock_offset    ///< Threadblock's initial offset within the problem size space
  ) {

    broadcast_fragment.clear();

    // If no pointer is supplied, leave the fragment zeroed and avoid memory accesses
    if (!broadcast_ptr) {
      return;
    }

    int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column();

    int thread_column_idx = threadblock_offset.column() + thread_initial_column;
    broadcast_ptr += thread_initial_column;

    NumericArrayConverter<ElementCompute, ElementVector, BroadcastDetail::kElementsPerAccess> converter;
    using AccessType = AlignedArray<ElementVector, BroadcastDetail::kElementsPerAccess>;
    using ComputeFragmentType = Array<ElementCompute, BroadcastDetail::kElementsPerAccess>;

    ComputeFragmentType *frag_ptr = reinterpret_cast<ComputeFragmentType *>(&broadcast_fragment);

    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) {

      AccessType loaded;

      loaded.clear();

      if (thread_column_idx < problem_size.column()) {
        loaded = *reinterpret_cast<AccessType const *>(broadcast_ptr);
      }

      ComputeFragmentType cvt = converter(loaded);
      frag_ptr[j] = cvt;

      thread_column_idx += ThreadMap::Delta::kColumn;
      broadcast_ptr += ThreadMap::Delta::kColumn;
    }
  }

  template <class Seq>
  struct acc2smem_source_not_needed;

  template <size_t... Seq>
  struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
    template <int Advance>
    CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
                                      WarpTileIterator &warp_tile_iterator) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Advance; i++) {
        ++accum_fragment_iterator;
      }

      CUTLASS_PRAGMA_UNROLL
      for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
        typename AccumulatorFragmentIterator::Fragment accum_fragment;

        accum_fragment_iterator.load(accum_fragment);
        ++accum_fragment_iterator;

        warp_tile_iterator.store(accum_fragment);
        if (p < Base::kFragmentsPerIteration - 1) {
          warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
        }
      }

      if (Base::kFragmentsPerIteration > 1) {
        warp_tile_iterator.add_pointer_offset(kSmemPointerOffset *
                                              (1 - Base::kFragmentsPerIteration));
      }
    }

    CUTLASS_DEVICE
    static void push(size_t pos,
                     AccumulatorFragmentIterator const &iterator_begin,
                     WarpTileIterator &warp_tile_iterator) {
      int dummy[] = {
          (pos == (Seq * Base::kFragmentsPerIteration)) &&
          (helper<Seq * Base::kFragmentsPerIteration>(iterator_begin, warp_tile_iterator), 0)...};

      CUTLASS_UNUSED(dummy[0]);
    }
  };

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void compute_source_not_needed_(
    OutputOp const &output_op,                    ///< Output operator
    BroadcastFragment const &broadcast_fragment,  ///< Fragment containing the accumulated partial reduction over columns
    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
    TensorTileIterator tensor_iterator            ///< Threadblock tile iterator for additional tensor operand
  ) {

    //
    // Iterator over warp-level accumulator fragment
    //

    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

    //
    // Iterate over accumulator tile
    //

    // CUTLASS_PRAGMA_UNROLL
    #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1)
    for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) {

      //
      // Convert and store fragment
      //

      __syncthreads();

      acc2smem_source_not_needed<
          cutlass::make_index_sequence<OutputTileIterator::kIterations /
                                       Base::kFragmentsPerIteration>>::push(iter,
                                                                            accum_fragment_iterator,
                                                                            this->warp_tile_iterator_);

      __syncthreads();

      //
      // Load fragments from shared memory
      //

      CUTLASS_PRAGMA_UNROLL
      for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {

        typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];

        shared_load_iterator_.load(aligned_accum_fragment[0]);

        if (p < Base::kFragmentsPerIteration - 1) {
          shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
        }
        else if (kPartitionsK > 1) {

          plus<typename SharedLoadIterator::Fragment> add_fragments;

          CUTLASS_PRAGMA_UNROLL
          for (int i = 1; i < kPartitionsK; ++i) {
            shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
            shared_load_iterator_.load(aligned_accum_fragment[i]);
            aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
          }

          shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
        }

        //
        // Apply output operation
        //

        typename OutputTileIterator::Fragment frag_Z;
        typename TensorTileIterator::Fragment frag_T;

        apply_output_operator_source_not_needed_(
            frag_Z,
            frag_T,
            output_op,
            aligned_accum_fragment[0],
            broadcast_fragment);

        //
        // Conditionally store fragments
        //

        if (OutputOp::kStoreZ) {
          destination_iterator.store(frag_Z);
          ++destination_iterator;
        }

        if (OutputOp::kStoreT) {
          tensor_iterator.store(frag_T);
          ++tensor_iterator;
        }
      }

      if (Base::kFragmentsPerIteration > 1) {
        shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
      }
    }
  }

  template <class Seq>
  struct acc2smem_source_needed;

  template <size_t... Seq>
  struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
    template <int Advance>
    CUTLASS_DEVICE
    static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
                       WarpTileIterator &warp_tile_iterator) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Advance; i++) {
        ++accum_fragment_iterator;
      }

      typename AccumulatorFragmentIterator::Fragment accum_fragment;
      accum_fragment_iterator.load(accum_fragment);
      warp_tile_iterator.store(accum_fragment);
    }

    CUTLASS_DEVICE
    static void push(size_t pos,
                     AccumulatorFragmentIterator const &iterator_begin,
                     WarpTileIterator &warp_tile_iterator) {
      int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
    }
  };

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void compute_source_needed_(
    OutputOp const &output_op,                    ///< Output operator
    BroadcastFragment const &broadcast_fragment,  ///< Fragment containing the accumulated partial reduction over columns
    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator1,          ///< Tile iterator for first source accumulator matrix
    OutputTileIterator source_iterator2,          ///< Tile iterator for second source accumulator matrix
    TensorTileIterator tensor_iterator            ///< Threadblock tile iterator for additional tensor operand
  ) {

    typename OutputTileIterator::Fragment source_fragment1;
    source_fragment1.clear();
    typename OutputTileIterator::Fragment source_fragment2;
    source_fragment2.clear();

    //
    // Iterator over warp-level accumulator fragment
    //

    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

    //
    // Iterate over accumulator tile
    //

    #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
    for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {

      //
      // Load the source
      //

      source_iterator1.load(source_fragment1);
      ++source_iterator1;

      if (source_iterator2.enabled()) {
        source_iterator2.load(source_fragment2);
        ++source_iterator2;
      }

      //
      // Convert and store fragment
      //

      __syncthreads();

      acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
          iter, accum_fragment_iterator, this->warp_tile_iterator_);

      __syncthreads();

      //
      // Load fragments from shared memory
      //

      typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];

      shared_load_iterator_.load(aligned_accum_fragment[0]);

      // If the number of k-slices is > 1, perform a reduction amongst the k-slices
      if (kPartitionsK > 1)
      {
        plus<typename SharedLoadIterator::Fragment> add_fragments;
        const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK;

        CUTLASS_PRAGMA_UNROLL
        for (int i = 1; i < kPartitionsK; ++i) {
          shared_load_iterator_.add_tile_offset({tile_row_offset, 0});
          shared_load_iterator_.load(aligned_accum_fragment[i]);
          aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
        }

        shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK - 1) * tile_row_offset, 0});
      }

      //
      // Apply output operation
      //

      typename OutputTileIterator::Fragment frag_Z;
      typename TensorTileIterator::Fragment frag_T;

      apply_output_operator_(
          frag_Z,
          frag_T,
          output_op,
          aligned_accum_fragment[0],
          source_fragment1,
          source_fragment2,
          broadcast_fragment,
          source_iterator2.enabled());

      //
      // Conditionally store fragments
      //

      if (OutputOp::kStoreZ) {
        destination_iterator.store(frag_Z);
        ++destination_iterator;
      }

      if (OutputOp::kStoreT) {
        tensor_iterator.store(frag_T);
        ++tensor_iterator;
      }
    }
  }

  /// Helper to invoke the output functor over each vector of output
  CUTLASS_DEVICE
  void apply_output_operator_(
    typename OutputTileIterator::Fragment &frag_Z,
    typename TensorTileIterator::Fragment &frag_T,
    OutputOp const &output_op,
    typename SharedLoadIterator::Fragment const &frag_AB,
    typename OutputTileIterator::Fragment const &frag_C1,
    typename OutputTileIterator::Fragment const &frag_C2,
    BroadcastFragment const &frag_Broadcast,
    bool frag_C2_enabled) {

    using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
    using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
    using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;

    AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
    AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);

    AccumulatorAccessType const *frag_AB_ptr =
        reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);

    OutputAccessType const *frag_C1_ptr =
        reinterpret_cast<OutputAccessType const *>(&frag_C1);
    OutputAccessType const *frag_C2_ptr =
        reinterpret_cast<OutputAccessType const *>(&frag_C2);

    AccessTypeBroadcast const *frag_Broadcast_ptr =
        reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);

    int const kOutputOpIterations =
        OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {
      if (frag_C2_enabled) {
        output_op(
            frag_Z_ptr[i],
            frag_T_ptr[i],
            frag_AB_ptr[i],
            frag_C1_ptr[i],
            frag_C2_ptr[i],
            frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
      } else {
        output_op(
            frag_Z_ptr[i],
            frag_T_ptr[i],
            frag_AB_ptr[i],
            frag_C1_ptr[i],
            frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
      }
    }
  }
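  // Note (assumption): source_iterator2.enabled() is expected to evaluate
  // uniformly across the threadblock, since it reflects iterator parameters
  // rather than per-thread state, so the branch above should not introduce
  // warp divergence.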

  /// Helper to invoke the output functor over each vector of output
  CUTLASS_DEVICE
  void apply_output_operator_source_not_needed_(
    typename OutputTileIterator::Fragment &frag_Z,
    typename TensorTileIterator::Fragment &frag_T,
    OutputOp const &output_op,
    typename SharedLoadIterator::Fragment const &frag_AB,
    BroadcastFragment const &frag_Broadcast) {

    using AccessTypeZ = Array<typename OutputTileIterator::Element, kElementsPerAccess>;
    using AccessTypeT = Array<typename TensorTileIterator::Element, kElementsPerAccess>;
    using AccessTypeBroadcast = Array<ElementCompute, kElementsPerAccess>;

    AccessTypeZ *frag_Z_ptr = reinterpret_cast<AccessTypeZ *>(&frag_Z);
    AccessTypeT *frag_T_ptr = reinterpret_cast<AccessTypeT *>(&frag_T);

    AccumulatorAccessType const *frag_AB_ptr =
        reinterpret_cast<AccumulatorAccessType const *>(&frag_AB);

    AccessTypeBroadcast const *frag_Broadcast_ptr =
        reinterpret_cast<AccessTypeBroadcast const *>(&frag_Broadcast);

    int const kOutputOpIterations =
        OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {

      output_op(
          frag_Z_ptr[i],
          frag_T_ptr[i],
          frag_AB_ptr[i],
          frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
    }
  }
};

////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
@@ -124,6 +124,8 @@ public:
  using Layout = layout::RowMajor;
  using LongIndex = typename Layout::LongIndex;

  static bool const kIsSingleSource = true;

  /// The complete warp-level accumulator tile
  using AccumulatorTile = typename Base::AccumulatorTile;

@@ -294,7 +296,7 @@ public:
  CUTLASS_DEVICE
  void operator()(
    OutputOp const &output_op,                ///< Output operator
    ElementVector * reduction_output_ptr,     ///< Reduction output vector
    OutputTileIterator destination_iterator,  ///< Tile iterator for destination
    AccumulatorTile const &accumulators,      ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator,       ///< Tile iterator for source accumulator matrix

@@ -51,7 +51,7 @@
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"

#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/epilogue_base_streamk.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"

////////////////////////////////////////////////////////////////////////////////
@@ -78,8 +78,21 @@ template <
  typename OutputOp_,
  /// Number of interleaved k
  int InterleavedK>
class InterleavedEpilogue {
 public:
class InterleavedEpilogue :
  public EpilogueBaseStreamK<
    Shape_,
    PartitionsK,
    WarpMmaOperator_,
    AccumulatorFragmentIterator_>
{
public:

  using BaseStreamK = EpilogueBaseStreamK<
    Shape_,
    PartitionsK,
    WarpMmaOperator_,
    AccumulatorFragmentIterator_>;

  using Shape = Shape_;
  using WarpMmaOperator = WarpMmaOperator_;
  static int const kPartitionsK = PartitionsK;
@@ -90,6 +103,9 @@ class InterleavedEpilogue {
  /// The complete warp-level accumulator tile
  using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;

  /// Fragment type used by the accumulator tile's fragment iterator
  using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment;

  /// Accumulator element
  using ElementAccumulator = typename AccumulatorTile::Element;

@@ -122,7 +138,8 @@ class InterleavedEpilogue {
  gemm::GemmShape<Shape::kM / WarpMmaOperator::Shape::kM,
                  Shape::kN / WarpMmaOperator::Shape::kN, kPartitionsK>;

public:

  static_assert(OutputTileIterator::kElementsPerAccess,
    "This must not be zero.");

@@ -134,15 +151,58 @@ class InterleavedEpilogue {
  struct SharedStorage {};

public:

  /// Constructor
  CUTLASS_DEVICE
  InterleavedEpilogue(
    SharedStorage &shared_storage,  ///< Shared storage object
    int thread_idx,                 ///< ID of a thread within the threadblock
    int warp_idx,                   ///< ID of warp within threadblock
    int lane_idx                    ///< Id of thread within warp
  ) {}
    int lane_idx)                   ///< Id of thread within warp
  :
    BaseStreamK(thread_idx)
  {}

  /// Aggregates the accumulator sets shared by peer blocks in the global workspace,
  /// performing epilogue computations, writing to output
  CUTLASS_DEVICE
  void reduce(
    int peer_idx_begin,
    int peer_idx_end,
    int reduce_fragment_idx,
    ElementAccumulator *element_workspace,
    OutputOp const &output_op,                ///< Output operator
    OutputTileIterator destination_iterator,  ///< Tile iterator for destination
    OutputTileIterator source_iterator)       ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
  {
    // Reduce peer accumulator fragments into one fragment
    AccumulatorFragment accum_fragment;
    BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace);

    // Source-fragment data (zero-initialized for scenarios where the
    // output operator allows us to skip loading it from global input)
    typename OutputTileIterator::Fragment source_fragment;
    source_fragment.clear();

    if (output_op.is_source_needed())
    {
      source_iterator += reduce_fragment_idx;
      source_iterator.load(source_fragment);
    }

    // Compute the output result
    typename OutputTileIterator::Fragment output_fragment;

    // Apply the output operator
    apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment);

    // Store the final result
    destination_iterator += reduce_fragment_idx;
    destination_iterator.store(output_fragment);
  }
|
||||
|
||||
/// Streams the result to global memory
|
||||
CUTLASS_DEVICE
|
||||
@ -194,7 +254,7 @@ class InterleavedEpilogue {
|
||||
//
|
||||
|
||||
typename OutputTileIterator::Fragment output_fragment;
|
||||
apply_output_operator_source_not_needed_(output_op, output_fragment, accum_fragment);
|
||||
apply_output_operator_source_not_needed(output_fragment, output_op, accum_fragment);
|
||||
|
||||
//
|
||||
// Store the final result
|
||||
@ -257,7 +317,7 @@ class InterleavedEpilogue {
|
||||
//
|
||||
|
||||
typename OutputTileIterator::Fragment output_fragment;
|
||||
apply_output_operator_source_needed_(output_op, output_fragment, accum_fragment, source_fragment);
|
||||
apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment);
|
||||
|
||||
//
|
||||
// Store the final result
|
||||
@ -269,15 +329,16 @@ class InterleavedEpilogue {
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
protected:
|
||||
|
||||
/// Helper to invoke the output functor over each vector of output
|
||||
CUTLASS_DEVICE
|
||||
void apply_output_operator_source_needed_(
|
||||
OutputOp const &output_op, ///< Output operator
|
||||
typename OutputTileIterator::Fragment &output_fragment,
|
||||
typename AccumulatorFragmentIterator::Fragment const
|
||||
&aligned_accum_fragment,
|
||||
typename OutputTileIterator::Fragment const &source_fragment) {
|
||||
void apply_output_operator(
|
||||
typename OutputTileIterator::Fragment &output_fragment,
|
||||
OutputOp const &output_op,
|
||||
typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment,
|
||||
typename OutputTileIterator::Fragment const &source_fragment)
|
||||
{
|
||||
OutputAccessType *output_frag_ptr =
|
||||
reinterpret_cast<OutputAccessType *>(&output_fragment);
|
||||
|
||||
@ -300,11 +361,11 @@ class InterleavedEpilogue {
|
||||
|
||||
/// Helper to invoke the output functor over each vector of output
|
||||
CUTLASS_DEVICE
|
||||
void apply_output_operator_source_not_needed_(
|
||||
OutputOp const &output_op, ///< Output operator
|
||||
typename OutputTileIterator::Fragment &output_fragment,
|
||||
typename AccumulatorFragmentIterator::Fragment const
|
||||
&aligned_accum_fragment) {
|
||||
void apply_output_operator_source_not_needed(
|
||||
typename OutputTileIterator::Fragment &output_fragment,
|
||||
OutputOp const &output_op,
|
||||
typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment)
|
||||
{
|
||||
OutputAccessType *output_frag_ptr =
|
||||
reinterpret_cast<OutputAccessType *>(&output_fragment);
|
||||
|
||||
|
||||
@ -680,6 +680,9 @@ public:
|
||||
state_[2] = 0;
|
||||
byte_pointer_ += params_.advance_tile;
|
||||
store_byte_pointer_ += params_.advance_tile;
|
||||
|
||||
thread_start_row_ += ThreadMap::Shape::kGroup * ThreadMap::Shape::kRow
|
||||
* ThreadMap::Shape::kCluster * ThreadMap::Shape::kTile;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -687,6 +690,60 @@ public:
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Advances a number of positions to load or store
|
||||
CUTLASS_HOST_DEVICE
|
||||
PredicatedTileIterator &operator+=(int increment)
|
||||
{
|
||||
// Row
|
||||
state_[0] += increment;
|
||||
int increment_row = state_[0] / ThreadMap::Count::kRow;
|
||||
state_[0] = state_[0] % ThreadMap::Count::kRow;
|
||||
|
||||
byte_pointer_ += (params_.advance_row * increment);
|
||||
store_byte_pointer_ += (params_.advance_row * increment);
|
||||
thread_start_row_ += (ThreadMap::Shape::kRow * increment);
|
||||
|
||||
// Group
|
||||
state_[1] += increment_row;
|
||||
int increment_group = state_[1] / ThreadMap::Count::kGroup;
|
||||
state_[1] = state_[1] % ThreadMap::Count::kGroup;
|
||||
|
||||
byte_pointer_ += (params_.advance_group * increment_row);
|
||||
store_byte_pointer_ += (params_.advance_group * increment_row);
|
||||
thread_start_row_ +=
|
||||
(ThreadMap::Shape::kGroup - 1) *
|
||||
ThreadMap::Shape::kRow *
|
||||
ThreadMap::Count::kRow *
|
||||
increment_row;
|
||||
|
||||
|
||||
// Cluster
|
||||
state_[2] += increment_group;
|
||||
int increment_cluster = state_[2] / ThreadMap::Count::kCluster;
|
||||
state_[2] = state_[2] % ThreadMap::Count::kCluster;
|
||||
|
||||
byte_pointer_ += (params_.advance_cluster * increment_group);
|
||||
store_byte_pointer_ += (params_.advance_cluster * increment_group);
|
||||
thread_start_row_ +=
|
||||
ThreadMap::Count::kGroup *
|
||||
ThreadMap::Shape::kGroup *
|
||||
ThreadMap::Count::kRow *
|
||||
ThreadMap::Shape::kRow *
|
||||
increment_group;
|
||||
|
||||
// Tile
|
||||
byte_pointer_ += (params_.advance_tile * increment_cluster);
|
||||
store_byte_pointer_ += (params_.advance_tile * increment_cluster);
|
||||
thread_start_row_ +=
|
||||
ThreadMap::Shape::kGroup *
|
||||
ThreadMap::Shape::kRow *
|
||||
ThreadMap::Shape::kCluster *
|
||||
ThreadMap::Shape::kTile *
|
||||
increment_cluster;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
///< Efficiently disables all accesses guarded by mask
|
||||
CUTLASS_DEVICE void clear_mask() {
|
||||
mask_.clear();
|
||||
@ -944,6 +1001,23 @@ public:
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Advances a number of positions to load or store
|
||||
CUTLASS_HOST_DEVICE
|
||||
InterleavedPredicatedTileIterator &operator+=(int increment)
|
||||
{
|
||||
// Contiguous
|
||||
iteration_contiguous_ += increment;
|
||||
int increment_strided = iteration_contiguous_ / ThreadMap::Iterations::kContiguous;
|
||||
iteration_contiguous_ = iteration_contiguous_ % ThreadMap::Iterations::kContiguous;
|
||||
byte_pointer_ += (params_.advance_row * increment);
|
||||
|
||||
// Strided
|
||||
iteration_strided_ += increment_strided;
|
||||
byte_pointer_ += (params_.advance_column * increment_strided);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
///< Efficiently disables all accesses guarded by mask
|
||||
CUTLASS_DEVICE void clear_mask() {
|
||||
mask_.clear();
|
||||
|
||||
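The new operator+= above behaves like a mixed-radix counter: advancing N positions carries from rows into groups, groups into clusters, and clusters into tiles, with each level contributing its own pointer advance. A minimal scalar sketch with hypothetical counts and advances (not the real iterator state) showing that one batched advance equals repeated unit advances:

#include <cstdio>

int main() {
  const int kRowCount = 4, kGroupCount = 2, kClusterCount = 2;
  const long kAdvRow = 1, kAdvGroup = 10, kAdvCluster = 100, kAdvTile = 1000;
  int state[3] = {0, 0, 0};  // row, group, cluster counters
  long offset = 0;           // stands in for byte_pointer_

  auto advance = [&](int increment) {
    state[0] += increment;
    int carry_group = state[0] / kRowCount;     state[0] %= kRowCount;
    offset += kAdvRow * increment;
    state[1] += carry_group;
    int carry_cluster = state[1] / kGroupCount; state[1] %= kGroupCount;
    offset += kAdvGroup * carry_group;
    state[2] += carry_cluster;
    int carry_tile = state[2] / kClusterCount;  state[2] %= kClusterCount;
    offset += kAdvCluster * carry_cluster;
    offset += kAdvTile * carry_tile;
  };

  advance(7);                       // one batched advance...
  long batched = offset;
  offset = 0; state[0] = state[1] = state[2] = 0;
  for (int i = 0; i < 7; ++i) advance(1);  // ...matches seven unit advances
  printf("batched=%ld unit=%ld\n", batched, offset);
  return 0;
}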
@@ -0,0 +1,445 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.

*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/permute.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/epilogue/threadblock/output_tile_thread_map.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h"
#include "cutlass/conv/conv2d_problem_size.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {

////////////////////////////////////////////////////////////////////////////////

namespace epilogue {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Tile iterator used to load and store output tile from global memory in epilogue.
///
/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator
///
template <
typename ThreadMap_, ///< Thread map (concept: PitchLinearThreadMap)
typename Element_, ///< Element data type
typename ThreadOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1>,
typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1>
>
class PredicatedTileIteratorDirectConv {
public:
using ThreadMap = ThreadMap_;
using Shape = typename ThreadMap::Shape;
using ThreadOutputShape = ThreadOutputShape_;
using ThreadBlockOutputShape = ThreadBlockOutputShape_;

using Element = Element_;

using Layout = layout::RowMajor;
using TensorRef = TensorRef<Element, Layout>;
using ConstTensorRef = typename TensorRef::ConstTensorRef;

using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorCoord = MatrixCoord;

static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
static int const kThreads = ThreadMap::kThreads;

using ConvProblemSize = typename cutlass::conv::Conv2dProblemSize;

/// Fragment object
using Fragment = Array<Element, ThreadMap::Iterations::kCount * kElementsPerAccess>;

/// Memory access size
using AccessType = AlignedArray<Element, kElementsPerAccess>;

static int const kLoadsPerAccess = AccessType::kElements / AccessType::kElements;

using ThreadTileCount = MatrixShape<
ThreadBlockOutputShape::kH / ThreadOutputShape::kH,
ThreadBlockOutputShape::kW / ThreadOutputShape::kW
>;

//
// Parameters struct
//

/// Uses a non-template class
struct Params : PredicatedTileIteratorDirect2dConvParams {
using Base = PredicatedTileIteratorDirect2dConvParams;

CUTLASS_HOST_DEVICE
Params() { }

CUTLASS_HOST_DEVICE
Params(Layout const &layout, cutlass::conv::Conv2dProblemSize const &problem_size):
PredicatedTileIteratorDirect2dConvParams(
layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess,
problem_size,
{ThreadBlockOutputShape::kH, ThreadBlockOutputShape::kW}
)
{ }

CUTLASS_HOST_DEVICE
Params(Base const &base) :
Base(base) { }
};

/// Mask object
struct Mask {

static int const kCount = ThreadMap::Iterations::kContiguous;

/// Predicate state
bool predicates[kCount];

//
// Mask
//
CUTLASS_HOST_DEVICE
Mask() {
enable();
}

///< Efficiently disables all accesses guarded by mask
CUTLASS_HOST_DEVICE void clear() {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kCount; ++i) {
predicates[i] = false;
}
}

///< Efficiently enables all accesses guarded by mask
CUTLASS_DEVICE void enable() {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kCount; ++i) {
predicates[i] = true;
}
}
};

private:

//
// Data members
//

/// Parameters structure containing reference and precomputed state.
PredicatedTileIteratorDirect2dConvParams params_;

/// Byte-level pointer
uint8_t *byte_pointer_;

///
Element *pointer_;

/// Array of boolean values to contain steady-state predicates
Mask mask_;

/// Extent of the matrix tile in rows
Index extent_row_;

/// Extent of the matrix tile in columns
Index extent_column_;

/// A thread's starting row position (assuming steady-state predicates have been computed)
Index thread_start_row_;

/// A thread's starting column
Index thread_start_column_;

/// Initial thread output location
int thread_start_n_, thread_start_p_, thread_start_q_;

/// Current threadblock tile index
int tile_index_;

//
// Static asserts about internal strides
//

static_assert(sizeof(extent_row_) == 4, "Expected 32b extents");
static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents");
static_assert(sizeof(PredicatedTileIteratorDirect2dConvParams::stride) == 8, "Expected 64b strides");

private:

//
// Methods
//

public:

//
// Methods
//

/// Constructor
CUTLASS_DEVICE
PredicatedTileIteratorDirectConv(
PredicatedTileIteratorDirect2dConvParams const & params,
Element *pointer,
TensorCoord extent,
int thread_idx,
TensorCoord threadblock_offset = TensorCoord()
):
params_(params), pointer_(pointer)
{

TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx);

extent_row_ = extent.row();
extent_column_ = extent.column();

// stride dim (PQ)
thread_start_row_ = thread_offset.column();
// contiguous dim (Channels)
thread_start_column_ = threadblock_offset.column() + thread_offset.row();

tile_index_ = threadblock_offset.row();

set_tile_index(0);
}

/// Sets the tile index, recomputing thread start coordinates and predicates
CUTLASS_HOST_DEVICE
void set_tile_index(const int index) {

int residual;
params_.pq_divmod(thread_start_n_, residual, tile_index_ + index);
params_.q_divmod(thread_start_p_, thread_start_q_, residual);

// Compute the base output coord of ThreadBlock
thread_start_p_ *= ThreadBlockOutputShape::kH;
thread_start_q_ *= ThreadBlockOutputShape::kW;

// Initialize predicates
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
mask_.predicates[c] = ((thread_start_column_
+ c * ThreadMap::Delta::kContiguous) < extent_column_);
}

// Null pointer performs no accesses
if (!pointer_) {
mask_.clear();
}

}

/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
}

/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_byte_offset(Fragment &frag, int64_t byte_offset) const {
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
int frag_base_idx = s * ThreadMap::Iterations::kContiguous + c;

int current_row = thread_start_row_ + s * ThreadMap::Delta::kStrided;
int p = current_row / ThreadBlockOutputShape::kW;
int q = current_row % ThreadBlockOutputShape::kW;

int current_p = thread_start_p_ + p;
int current_q = thread_start_q_ + q;

bool row_guard = (current_p) < params_.P && (current_q) < params_.Q &&
(thread_start_n_ < params_.N) && current_row < ThreadMap::Shape::kStrided;

int output_row_offset =
thread_start_n_ * params_.stride_n + current_p * params_.stride_p + current_q;

uint8_t *byte_pointer =
reinterpret_cast<uint8_t *>(pointer_) +
LongIndex(output_row_offset) * LongIndex(params_.stride) +
LongIndex(thread_start_column_ + c * ThreadMap::Delta::kContiguous) *
sizeof(AccessType) / kElementsPerAccess;

AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);

AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset);

bool guard = row_guard && mask_.predicates[c];

cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
frag_ptr[frag_base_idx], (void *)&memory_pointer[0], guard);
}
}
}

/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment &frag) const {
load_with_byte_offset(frag, 0);
}

/// Stores a fragment to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) const {
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
int frag_base_idx = s * ThreadMap::Iterations::kContiguous + c;

int current_row = thread_start_row_ + s * ThreadMap::Delta::kStrided;
int p = current_row / ThreadBlockOutputShape::kW;
int q = current_row % ThreadBlockOutputShape::kW;

int current_p = thread_start_p_ + p;
int current_q = thread_start_q_ + q;

bool row_guard = (current_p) < params_.P && (current_q) < params_.Q &&
(thread_start_n_ < params_.N) && current_row < ThreadMap::Shape::kStrided;

int output_row_offset =
thread_start_n_ * params_.stride_n + current_p * params_.stride_p + current_q;

uint8_t *byte_pointer =
reinterpret_cast<uint8_t *>(pointer_) +
LongIndex(output_row_offset) * LongIndex(params_.stride) +
LongIndex(thread_start_column_ + c * ThreadMap::Delta::kContiguous) *
sizeof(AccessType) / kElementsPerAccess;

AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);

AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset);

bool guard = row_guard && mask_.predicates[c];

cutlass::arch::global_store<AccessType, sizeof(AccessType)>(
frag_ptr[frag_base_idx], (void *)&memory_pointer[0], guard);
}
}
}

/// Stores a fragment to memory
CUTLASS_DEVICE
void store(Fragment const &frag) const {

store_with_byte_offset(frag, 0);
}

CUTLASS_DEVICE
MatrixCoord thread_start() const {
return MatrixCoord(thread_start_row_, thread_start_column_);
}

/// Need to get the thread start row from the tile iterator
CUTLASS_DEVICE
int32_t thread_start_row() const {
return thread_start_row_;
}

/// Need to get the thread start column from the tile iterator
CUTLASS_DEVICE
int32_t thread_start_column() const {
return thread_start_column_;
}

/// Extent of the matrix in rows
CUTLASS_DEVICE
Index extent_row() const {
return extent_row_;
}

/// Extent of the matrix in columns
CUTLASS_DEVICE
Index extent_column() const {
return extent_column_;
}

/// Advances to the next position to load or store
CUTLASS_HOST_DEVICE
PredicatedTileIteratorDirectConv &operator++() {
// do nothing

return *this;
}

///< Efficiently disables all accesses guarded by mask
CUTLASS_DEVICE void clear_mask() {
mask_.clear();
}

///< Efficiently enables all accesses guarded by mask
CUTLASS_DEVICE void enable_mask() {
mask_.enable();
}

///< Gets the mask
CUTLASS_DEVICE void get_mask(Mask &mask) const {
mask = mask_;
}

///< Sets the mask
CUTLASS_DEVICE void set_mask(Mask const &mask) {
mask_ = mask;
}
};

///////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
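Both load_with_byte_offset and store_with_byte_offset above map a linear row within the threadblock tile to an NPQ output coordinate and a guard. A standalone sketch of that mapping with hypothetical extents and strides (plain ints stand in for the iterator's state):

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical problem extent and strides, mirroring params_.{N,P,Q,stride_n,stride_p}
  const int N = 2, P = 8, Q = 8;
  const int64_t stride_n = int64_t(P) * Q;  // elements between batch images
  const int64_t stride_p = Q;               // elements between output rows

  const int kW = 4;                         // ThreadBlockOutputShape::kW
  int thread_start_n = 1, thread_start_p = 4, thread_start_q = 0;

  int current_row = 5;                      // linear row within the threadblock tile
  int p = current_row / kW;                 // decompose into (p, q) within the tile
  int q = current_row % kW;

  int current_p = thread_start_p + p;
  int current_q = thread_start_q + q;

  bool guard = current_p < P && current_q < Q && thread_start_n < N;
  int64_t offset = thread_start_n * stride_n + current_p * stride_p + current_q;

  printf("guard=%d offset=%lld\n", int(guard), (long long)offset);
  return 0;
}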
@@ -35,9 +35,12 @@
#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/matrix.h"

#include "cutlass/conv/conv2d_problem_size.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
@@ -245,6 +248,87 @@ struct PredicatedTileIteratorParams {
};


///////////////////////////////////////////////////////////////////////////////

//
// Parameters struct for PredicatedTileIteratorDirect2dConv
//

struct PredicatedTileIteratorDirect2dConvParams {
using Index = int32_t;
using LongIndex = int64_t;

//
// Data members
//
FastDivmod pq_divmod;
FastDivmod q_divmod;

LongIndex stride;
LongIndex stride_n;
LongIndex stride_p;

int N;
int P;
int Q;

//
// Methods
//

CUTLASS_HOST_DEVICE
Status initialize(LongIndex stride_,
cutlass::conv::Conv2dProblemSize const &problem_size,
MatrixCoord threadblock_output_shape) {
stride = stride_; // The stride per row of output tensor (bytes)
stride_n = problem_size.P * problem_size.Q;
stride_p = problem_size.Q;

N = problem_size.N;
P = problem_size.P;
Q = problem_size.Q;

// Fast divmod for output N, P, Q
if (threadblock_output_shape.row() != 0 && threadblock_output_shape.column() != 0) {
int tiles_p =
(problem_size.P + (threadblock_output_shape.row() - 1)) / (threadblock_output_shape.row());
int tiles_q = (problem_size.Q + (threadblock_output_shape.column() - 1)) /
(threadblock_output_shape.column());

pq_divmod = FastDivmod(tiles_p * tiles_q);
q_divmod = FastDivmod(tiles_q);
}

return Status::kSuccess;
}

CUTLASS_HOST_DEVICE
Status initialize(
Index stride_,
cutlass::conv::Conv2dProblemSize const &problem_size = cutlass::conv::Conv2dProblemSize(),
MatrixCoord threadblock_output_shape = MatrixCoord()) {
return initialize(LongIndex(stride_), problem_size, threadblock_output_shape);
}

CUTLASS_HOST_DEVICE
PredicatedTileIteratorDirect2dConvParams() { initialize(LongIndex(0)); }

CUTLASS_HOST_DEVICE
PredicatedTileIteratorDirect2dConvParams(Index stride,
cutlass::conv::Conv2dProblemSize const &problem_size,
MatrixCoord threadblock_output_shape) {
initialize(stride, problem_size, threadblock_output_shape);
}

CUTLASS_HOST_DEVICE
PredicatedTileIteratorDirect2dConvParams(LongIndex stride,
cutlass::conv::Conv2dProblemSize const &problem_size,
MatrixCoord threadblock_output_shape) {
initialize(stride, problem_size, threadblock_output_shape);
}
};

///////////////////////////////////////////////////////////////////////////////
// InterleavedPredicatedTileIterator
///////////////////////////////////////////////////////////////////////////////
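These params precompute FastDivmod objects over the per-image tile count so that set_tile_index() can decompose a linear tile index into (n, tile_p, tile_q) without division on the device. A host-side sketch of that decomposition with hypothetical sizes (plain division stands in for FastDivmod):

#include <cstdio>

int main() {
  const int P = 30, Q = 30;            // output height/width
  const int tb_h = 8, tb_w = 8;        // threadblock output tile shape
  int tiles_p = (P + tb_h - 1) / tb_h; // ceiling division, as in initialize()
  int tiles_q = (Q + tb_w - 1) / tb_w;

  int tile_index = 37;                 // linear index over (n, tile_p, tile_q)
  int n = tile_index / (tiles_p * tiles_q);        // pq_divmod analogue
  int residual = tile_index % (tiles_p * tiles_q);
  int tile_p = residual / tiles_q;                 // q_divmod analogue
  int tile_q = residual % tiles_q;

  // Base output coordinate of this threadblock tile:
  printf("n=%d p=%d q=%d\n", n, tile_p * tb_h, tile_q * tb_w);
  return 0;
}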
File diff suppressed because it is too large
@@ -201,6 +201,11 @@ public:
}
}

/// Set base smem address
CUTLASS_DEVICE
void set_smem_base_address(Index address) {
}

/// Loads a fragment
CUTLASS_DEVICE
void load(Fragment &frag) const {

@@ -234,6 +234,10 @@ public:
}
}

/// Set base smem address
CUTLASS_DEVICE
void set_smem_base_address(Index address) {}

/// Loads a fragment
CUTLASS_DEVICE
void load(Fragment &frag) const {
@@ -395,6 +399,10 @@ public:
}
}

/// Set base smem address
CUTLASS_DEVICE
void set_smem_base_address(Index address) {}

/// Loads a fragment
CUTLASS_DEVICE
void load(Fragment &frag) {
@@ -556,6 +564,10 @@ public:
}
}

/// Set base smem address
CUTLASS_DEVICE
void set_smem_base_address(Index address) {}

/// Loads a fragment
CUTLASS_DEVICE
void load(Fragment &frag) {
@@ -0,0 +1,194 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

This assumes the shared memory tile is in a permuted layout which avoids bank conflicts on loading.

When the fragment is loaded into registers, it matches the row-major thread map assumed by
the predicated tile iterator writing to global memory.

The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/

#pragma once

#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/threadblock/output_tile_thread_map.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Tile iterator used to load output tile from shared memory in epilogue.
///
/// Satisfies: ReadableTileIterator
///
template <typename ThreadMap_, ///< Thread map (concept: PitchLinearThreadMap)
typename Element_, ///< Element data type
int MaxAlignment = ThreadMap_::kElementsPerAccess * sizeof_bits<Element_>::value / 8>
class SharedLoadIteratorPitchLiner {
public:
using ThreadMap = ThreadMap_;
using Element = Element_;

using Layout = layout::RowMajor;
using TensorRef = TensorRef<Element, Layout>;
using ConstTensorRef = typename TensorRef::ConstTensorRef;

using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorCoord = MatrixCoord;

static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;

static int const kMinAlignment =
ThreadMap_::kElementsPerAccess * sizeof_bits<Element_>::value / 8;

static int const kAlignment = (MaxAlignment < kMinAlignment ? MaxAlignment : kMinAlignment);

static int const kThreads = ThreadMap::kThreads;

/// Fragment object
using Fragment = Array<Element, ThreadMap::Iterations::kCount * kElementsPerAccess>;

/// Memory access size
using AccessType = AlignedArray<Element, kElementsPerAccess, kAlignment>;

/// Vector type used for SMEM loads
using LoadType =
AlignedArray<Element,
const_min(128 / sizeof_bits<Element>::value, ThreadMap::kElementsPerAccess),
const_min(16, kAlignment)>;

static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements;

private:
//
// Data members
//

/// Byte-level pointer
uint8_t *byte_pointer_;

/// Stride along adjacent rows
int stride_;

/// Base address offset
Index base_smem_address_;

public:
//
// Methods
//

/// Constructor
CUTLASS_DEVICE
SharedLoadIteratorPitchLiner(TensorRef ref, int thread_idx)
: byte_pointer_(reinterpret_cast<uint8_t *>(ref.data())),
stride_((ref.stride(0) * sizeof_bits<Element>::value) / 8),
base_smem_address_(0) {
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx);

// Initialize pointer
// thread_offset.row() is contiguous dim
// thread_offset.column() is stride dim
byte_pointer_ += thread_offset.row() * sizeof(AccessType) / kElementsPerAccess +
thread_offset.column() * stride_;
}

/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
}

CUTLASS_DEVICE
void add_tile_offset(TensorCoord const &offset) {
byte_pointer_ +=
offset.row() * ThreadMap::StorageShape::kContiguous * sizeof(AccessType) / kElementsPerAccess +
offset.column() * ThreadMap::StorageShape::kStrided * stride_;
}

/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
uint8_t const *byte_pointer =
byte_pointer_ + s * ThreadMap::Delta::kStrided * stride_ +
c * ThreadMap::Delta::kContiguous * ThreadMap::kElementsPerAccess *
sizeof_bits<Element>::value / 8 +
pointer_offset * sizeof_bits<Element>::value / 8 + base_smem_address_;

int frag_base_idx = s * ThreadMap::Iterations::kContiguous + c;

LoadType *frag_ptr = reinterpret_cast<LoadType *>(&frag);

LoadType const *memory_pointer = reinterpret_cast<LoadType const *>(byte_pointer);

CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kLoadsPerAccess; ++v) {
frag_ptr[frag_base_idx * kLoadsPerAccess + v] = memory_pointer[v];
}
}
}
}

/// Set smem base address
CUTLASS_DEVICE
void set_smem_base_address(Index address) { base_smem_address_ = address; }

/// Loads a fragment
CUTLASS_DEVICE
void load(Fragment &frag) const { load_with_pointer_offset(frag, 0); }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
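The LoadType definition above caps each shared-memory load at 128 bits, so a wider per-thread access is decomposed into kLoadsPerAccess vector loads. A tiny host-side sketch of that arithmetic (element width and access size are illustrative assumptions):

#include <algorithm>
#include <cstdio>

int main() {
  const int element_bits = 32;          // e.g. float
  const int elements_per_access = 8;    // ThreadMap::kElementsPerAccess (assumed)

  // LoadType holds at most 128 bits' worth of elements:
  int load_elements = std::min(128 / element_bits, elements_per_access);  // 4
  int loads_per_access = elements_per_access / load_elements;             // 2

  printf("each %d-element access issues %d loads of %d elements\n",
         elements_per_access, loads_per_access, load_elements);
  return 0;
}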
@@ -240,12 +240,301 @@ public:
void load(Fragment &frag) const {
load_with_pointer_offset(frag, 0);
}

/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) {
}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Template for reading and writing tiles of accumulators to shared memory
template <typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape)
typename Operator_, ///< matrix multiply operation (concept: arch::Mma)
typename Element_, ///< data type of element to be written
typename Layout_, ///< target shared memory layout
typename MmaSimtPolicy_ ///< policy defining lane arrangement (concept: MmaSimtPolicy)
>
class TileIteratorSimtDirectConv {
public:

using WarpShape = WarpShape_;
using Operator = Operator_;
using Element = Element_;
using Layout = layout::RowMajor;

using TensorRef = TensorRef<Element, Layout>; ///< Tensor Reference object
using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor
using Index = typename TensorRef::Index;
using LongIndex = typename TensorRef::LongIndex;

using Policy = SimtPolicy<WarpShape, Operator, Layout, MmaSimtPolicy_>;

/// Shape of the tile in memory
using Shape = MatrixShape<Policy::kRowsPerIteration, WarpShape::kN>;

/// This is the fragment size produced by one access of the iterator.
using Fragment = Array<typename Operator::ElementC, Policy::kElementsPerIteration>;

/// This is the complete warp-level accumulator tile.
using AccumulatorTile = Array<typename Operator::ElementC, Policy::kAccumulatorElementCount>;

/// Number of times this iterator can be incremented
static int const kIterations = Policy::kIterations;

/// Padding quantity
using Padding = MatrixShape<0, 0>;

private:
/// Storage type for accessing memory
using AccessType = AlignedArray<
Element,
Policy::kElementsPerAccess
>;

//
// Data members
//

/// Internal pointer to memory
AccessType *pointer_;

/// Internal layout object
Layout layout_;

/// Base smem offset
Index base_smem_address_;

public:
/// Default constructor
CUTLASS_HOST_DEVICE
TileIteratorSimtDirectConv() : pointer_(nullptr) {}

/// Constructor from TensorRef
CUTLASS_HOST_DEVICE
TileIteratorSimtDirectConv(
TensorRef const &ref,
unsigned lane_id
):
pointer_(reinterpret_cast<AccessType *>(ref.data())),
layout_(ref.stride()[0] / AccessType::kElements) {

auto lane_layout = Policy::MmaSimtPolicy::get_lane_layout();
MatrixCoord lane_offset = lane_layout.inverse(lane_id);

pointer_ += layout_({
lane_offset.row(),
lane_offset.column() * Policy::kElementsPerAccess / int(AccessType::kElements)
});
}

/// Adds a pointer offset
CUTLASS_HOST_DEVICE
TileIteratorSimtDirectConv & add_pointer_offset(Index pointer_offset) {
pointer_ += pointer_offset / AccessType::kElements;
return *this;
}

///< advances in units of whole tiles along the logical coordinate space of the tensor
CUTLASS_HOST_DEVICE
TileIteratorSimtDirectConv & add_tile_offset(TensorCoord const &tile_offset) {

pointer_ += layout_({
tile_offset.row() * Shape::kRow,
(tile_offset.column() * Shape::kColumn / int(AccessType::kElements))
});

return *this;
}

///< advances in units of whole tiles along the logical coordinate space of the tensor
CUTLASS_HOST_DEVICE
TileIteratorSimtDirectConv & operator+=(TensorCoord const &tile_offset) {

add_tile_offset(tile_offset);

return *this;
}

/// Store
CUTLASS_HOST_DEVICE
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {

// original vector stores
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
AccessType * load_pointer_ = reinterpret_cast<AccessType *>(reinterpret_cast<uint8_t *>(pointer_) + base_smem_address_);
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Policy::kAccessesPerIteration; ++n) {
load_pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)] = frag_ptr[n];
}
}

/// Store
CUTLASS_HOST_DEVICE
void store(Fragment const &frag) {
store_with_pointer_offset(frag, 0);
}

/// Load
CUTLASS_HOST_DEVICE
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {

AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);

CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Policy::kAccessesPerIteration; ++n) {
frag_ptr[n] = pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)];
}
}

/// Load
CUTLASS_HOST_DEVICE
void load(Fragment &frag) const {
load_with_pointer_offset(frag, 0);
}

/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) {
base_smem_address_ = address;
}

};

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Template for reading and writing tiles of accumulators to shared memory
template <typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape)
typename ThreadOutputShape_, /// Size of the matrix to load (concept: TensorNHWC)
typename ThreadBlockOutputShape_, /// Size of the matrix to load (concept: TensorNHWC)
typename Operator_, ///< matrix multiply operation (concept: arch::Mma)
typename Element_, ///< data type of element to be written
typename Layout_, ///< target shared memory layout
typename MmaSimtPolicy_ ///< policy defining lane arrangement (concept: MmaSimtPolicy)
>
class TileIteratorSimtDirect2dConv {
public:
using WarpShape = WarpShape_;
using ThreadOutputShape = ThreadOutputShape_;
using ThreadBlockOutputShape = ThreadBlockOutputShape_;
using Operator = Operator_;
using Element = Element_;
using Layout = layout::RowMajor;
using MmaSimtPolicy = MmaSimtPolicy_;

using TensorRef = TensorRef<Element, Layout>; ///< Tensor Reference object
using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor
using Index = typename TensorRef::Index;
using LongIndex = typename TensorRef::LongIndex;

// Thread-level shape of a fragment
using ThreadShape = MatrixShape<ThreadOutputShape::kNHW, ThreadOutputShape::kC>;

static_assert(!(ThreadShape::kColumn % MmaSimtPolicy::LaneMmaShape::kN),
"Thread-level GEMM must be divisible by Policy::LaneMmaShape.");

using ThreadTileCount = MatrixShape<ThreadBlockOutputShape::kH / ThreadOutputShape::kH,
ThreadBlockOutputShape::kW / ThreadOutputShape::kW>;

using Iterations =
MatrixShape<ThreadShape::kRow, ThreadShape::kColumn / MmaSimtPolicy::LaneMmaShape::kN>;

/// This is the complete warp-level accumulator tile.
using AccumulatorTile = typename Operator::FragmentC;

/// This is the fragment size produced by one access of the iterator.
using Fragment = AccumulatorTile;

/// Padding quantity
using Padding = MatrixShape<0, 0>;

private:
// Storage type for accessing memory
using AccessType = AlignedArray<Element, MmaSimtPolicy::LaneMmaShape::kN>;
//
// Data members
//

/// Internal pointer to memory
AccessType *pointer_;

/// Internal layout object
Layout layout_;

/// Base smem offset
Index base_smem_address_;

public:
/// Default constructor
CUTLASS_HOST_DEVICE
TileIteratorSimtDirect2dConv() : pointer_(nullptr) {}

/// Constructor from TensorRef
CUTLASS_HOST_DEVICE
TileIteratorSimtDirect2dConv(TensorRef const &ref, unsigned thread_id, unsigned lane_id)
: pointer_(reinterpret_cast<AccessType *>(ref.data())),
layout_(ref.stride()[0] / AccessType::kElements) {

auto lane_layout = MmaSimtPolicy::get_lane_layout();

MatrixCoord lane_offset = lane_layout.inverse(lane_id);

// Get base HW offset of current threads
const int threadgroup = thread_id / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC);
const int base_p = (threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH;
const int base_q = (threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW;

const int row_offset = base_p * ThreadBlockOutputShape::kW + base_q;

pointer_ += layout_(
{row_offset,
lane_offset.column() * MmaSimtPolicy::LaneMmaShape::kN / int(AccessType::kElements)});
}

/// Adds a pointer offset
CUTLASS_HOST_DEVICE
TileIteratorSimtDirect2dConv &add_pointer_offset(Index pointer_offset) {
pointer_ += pointer_offset / AccessType::kElements;
return *this;
}

/// Store
CUTLASS_HOST_DEVICE
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
AccessType *storer_pointer_ =
reinterpret_cast<AccessType *>(reinterpret_cast<uint8_t *>(pointer_) + base_smem_address_);
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);

CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < ThreadOutputShape::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < ThreadOutputShape::kW; ++w) {
CUTLASS_PRAGMA_UNROLL
for (int col = 0; col < Iterations::kColumn; ++col) {
int offset = (w + h * ThreadBlockOutputShape::kW) *
(ThreadBlockOutputShape::kC / AccessType::kElements) +
col;
storer_pointer_[offset + pointer_offset / int(AccessType::kElements)] =
frag_ptr[w + h * ThreadOutputShape::kW + col];
}
}
}
}

/// Store
CUTLASS_HOST_DEVICE
void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); }

/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) { base_smem_address_ = address; }
};
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Template for reading and writing tiles of accumulators to shared memory
template <
typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape)
@@ -482,6 +771,10 @@ public:
return add_tile_offset({1, 0});
}

/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) {
}
};
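The triple loop in TileIteratorSimtDirect2dConv::store_with_pointer_offset above walks the thread's (h, w) output positions and column iterations, mapping each to a shared-memory vector offset. A host-side sketch of just that index arithmetic (all shapes are hypothetical):

#include <cstdio>

int main() {
  const int kH = 2, kW = 2;          // ThreadOutputShape::kH/kW
  const int tb_kW = 8, tb_kC = 32;   // ThreadBlockOutputShape::kW/kC
  const int access_elements = 4;     // AccessType::kElements
  const int iters_col = 2;           // Iterations::kColumn

  for (int h = 0; h < kH; ++h)
    for (int w = 0; w < kW; ++w)
      for (int col = 0; col < iters_col; ++col) {
        // Row-major over the tile's (h*tb_kW + w) positions, with
        // tb_kC / access_elements vectors per position:
        int offset = (w + h * tb_kW) * (tb_kC / access_elements) + col;
        printf("h=%d w=%d col=%d -> offset %d\n", h, w, col, offset);
      }
  return 0;
}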
@@ -228,6 +228,11 @@ public:
TileIteratorTensorOp & operator++() {
return add_tile_offset({1, 0});
}

/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) {
}
};

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -420,6 +425,11 @@ public:
TileIteratorTensorOp & operator++() {
return add_tile_offset({0, 1});
}

/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) {
}
};


@@ -645,6 +655,11 @@ public:
TileIteratorTensorOpCanonical & operator++() {
return add_tile_offset({1, 0});
}

/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) {
}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

@@ -304,6 +304,11 @@ public:
void load(Fragment &frag) const {
load_with_pointer_offset(frag, 0);
}

/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) {
}
};

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -506,6 +511,11 @@ public:
void store(Fragment const &frag) {
store_with_pointer_offset(frag, 0);
}

/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) {
}
};

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -697,6 +707,11 @@ public:
void store(Fragment const &frag) {
store_with_pointer_offset(frag, 0);
}

/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) {
}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

@@ -242,6 +242,11 @@ public:
void load(Fragment const &frag) {
load_with_pointer_offset(frag, 0);
}

/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) {
}
};

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -419,6 +424,11 @@ public:
void load(Fragment const &frag) {
load_with_pointer_offset(frag, 0);
}

/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) {
}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

@@ -207,6 +207,12 @@ public:
void load(Fragment &frag) const {
load_with_pointer_offset(frag, 0);
}


/// Set smem base address
CUTLASS_HOST_DEVICE
void set_smem_base_address(Index address) {
}
};

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -280,6 +280,36 @@ struct FastDivmod {
unsigned int multiplier;
unsigned int shift_right;

/// Find quotient and remainder using device-side intrinsics
CUTLASS_HOST_DEVICE
void fast_divmod(int& quotient, int& remainder, int dividend) const {

#if defined(__CUDA_ARCH__)
// Use IMUL.HI if divisor != 1, else simply copy the source.
quotient = (divisor != 1) ? __umulhi(dividend, multiplier) >> shift_right : dividend;
#else
quotient = int((divisor != 1) ? int(((int64_t)dividend * multiplier) >> 32) >> shift_right : dividend);
#endif

// The remainder.
remainder = dividend - (quotient * divisor);
}

/// For long int input
CUTLASS_HOST_DEVICE
void fast_divmod(int& quotient, int64_t& remainder, int64_t dividend) const {

#if defined(__CUDA_ARCH__)
// Use IMUL.HI if divisor != 1, else simply copy the source.
quotient = (divisor != 1) ? __umulhi(dividend, multiplier) >> shift_right : dividend;
#else
quotient = int((divisor != 1) ? ((dividend * multiplier) >> 32) >> shift_right : dividend);
#endif
// The remainder.
remainder = dividend - (quotient * divisor);
}


/// Construct the FastDivmod object, in host code ideally.
///
/// This precomputes some values based on the divisor and is computationally expensive.
@@ -288,17 +318,35 @@ struct FastDivmod {
FastDivmod(): divisor(0), multiplier(0), shift_right(0) { }

CUTLASS_HOST_DEVICE
FastDivmod(int divisor_): divisor(divisor_) {
find_divisor(multiplier, shift_right, divisor);
FastDivmod(int divisor): divisor(divisor) {

if (divisor != 1) {
unsigned int p = 31 + find_log2(divisor);
unsigned m = unsigned(((1ull << p) + unsigned(divisor) - 1) / unsigned(divisor));

multiplier = m;
shift_right = p - 32;
} else {
multiplier = 0;
shift_right = 0;
}
}

/// Computes integer division and modulus using precomputed values. This is computationally
/// inexpensive.
CUTLASS_HOST_DEVICE
void operator()(int &quotient, int &remainder, int dividend) const {
fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right);
fast_divmod(quotient, remainder, dividend);
}

/// Computes integer division using precomputed values. This is computationally
/// inexpensive.
CUTLASS_HOST_DEVICE
int div(int dividend) const {
int quotient, remainder;
fast_divmod(quotient, remainder, dividend);
return quotient;
}

/// Computes integer division and modulus using precomputed values. This is computationally
/// inexpensive.
@@ -307,7 +355,7 @@ struct FastDivmod {
CUTLASS_HOST_DEVICE
int divmod(int &remainder, int dividend) const {
int quotient;
fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right);
fast_divmod(quotient, remainder, dividend);
return quotient;
}

@@ -315,7 +363,7 @@ struct FastDivmod {
/// inexpensive.
CUTLASS_HOST_DEVICE
void operator()(int &quotient, int64_t &remainder, int64_t dividend) const {
fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right);
fast_divmod(quotient, remainder, dividend);
}

/// Computes integer division and modulus using precomputed values. This is computationally
@@ -323,9 +371,14 @@ struct FastDivmod {
CUTLASS_HOST_DEVICE
int divmod(int64_t &remainder, int64_t dividend) const {
int quotient;
fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right);
fast_divmod(quotient, remainder, dividend);
return quotient;
}

/// Returns the divisor when cast to integer
CUTLASS_HOST_DEVICE
operator int() const { return divisor; }

};

/////////////////////////////////////////////////////////////////////////////////////////////////
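The constructor above implements the classic magic-number division: for a divisor d it picks p = 31 + ceil(log2(d)) and m = ceil(2^p / d), after which quotient(x) = (x * m) >> p, computed on device as __umulhi followed by a shift. A standalone host-side check of that identity (this is a sketch reproducing the math, not the CUTLASS API):

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  for (unsigned d = 2; d <= 64; ++d) {
    unsigned p = 31;
    while ((1ull << (p - 31)) < d) ++p;        // p = 31 + ceil(log2(d))
    uint64_t m = ((1ull << p) + d - 1) / d;    // fits in 32 bits for these d
    unsigned shift_right = p - 32;
    for (uint32_t x = 0; x < 100000; x += 7) {
      // __umulhi analogue: high 32 bits of the 64-bit product, then shift
      uint32_t q = uint32_t((uint64_t(x) * m) >> 32) >> shift_right;
      assert(q == x / d);                      // matches true integer division
    }
  }
  printf("fast divmod identity holds for all tested divisors\n");
  return 0;
}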
1213
include/cutlass/float8.h
Normal file
File diff suppressed because it is too large
65
include/cutlass/floating_point_nvrtc.h
Normal file
@ -0,0 +1,65 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*!
  \file
  \brief Defines categories for floating point numbers for use in NVRTC-compiled code
*/

#pragma once

namespace cutlass {

///////////////////////////////////////////////////////////////////////////////////////////////////

// All floating-point numbers can be put in one of these categories.
enum {
  FP_NAN =
# define FP_NAN 0
    FP_NAN,
  FP_INFINITE =
# define FP_INFINITE 1
    FP_INFINITE,
  FP_ZERO =
# define FP_ZERO 2
    FP_ZERO,
  FP_SUBNORMAL =
# define FP_SUBNORMAL 3
    FP_SUBNORMAL,
  FP_NORMAL =
# define FP_NORMAL 4
    FP_NORMAL
};

///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////
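Editor's note (not part of the file): this is the glibc <math.h> idiom. Each "# define" sits inside the preceding enumerator's initializer, so the macro is defined first and then immediately expanded as that enumerator's value (FP_NAN = 0, FP_INFINITE = 1, and so on). The names therefore exist both as enumerators and as preprocessor macros, letting NVRTC-compiled code use them the way host code uses the standard classification macros from <math.h>.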
File diff suppressed because it is too large
@ -211,7 +211,7 @@ public:
  /// Computes the maximum number of active blocks per multiprocessor
  static int maximum_active_blocks(int smem_capacity = -1) {

    CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()");
    CUTLASS_TRACE_HOST("BaseGrouped::maximum_active_blocks()");

    int smem_size = int(sizeof(typename BaseKernel::SharedStorage));

@ -349,7 +349,7 @@ public:
  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

    CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace "
    CUTLASS_TRACE_HOST("BaseGrouped::initialize() - workspace "
      << workspace << ", stream: " << (stream ? "non-null" : "null"));

    // Workspace

@ -765,6 +765,52 @@ struct DefaultGemmConfiguration<
////////////////////////////////////////////////////////////////////////////////

////////////////////////////////////////////////////////////////////////////////

template <typename ElementC,
          typename ElementAccumulator>
struct DefaultGemmConfiguration<arch::OpClassTensorOp, arch::Sm90, double,
                                double, ElementC, ElementAccumulator> {

  static int const kAlignmentA = 1;
  static int const kAlignmentB = 1;

  using ThreadblockShape = GemmShape<128, 256, 64>;
  using WarpShape = GemmShape<64, 64, 64>;
  using InstructionShape = GemmShape<16, 8, 4>;
  static int const kStages = 3;

  using EpilogueOutputOp = epilogue::thread::LinearCombination<
      ElementC, 128 / sizeof_bits<ElementC>::value, ElementAccumulator,
      ElementAccumulator>;

  using Operator = arch::OpMultiplyAdd;
};

template <>
struct DefaultGemmConfiguration<
  arch::OpClassTensorOp,
  arch::Sm90,
  complex<double>,
  complex<double>,
  complex<double>,
  complex<double>
> {

  static int const kAlignmentA = 1;
  static int const kAlignmentB = 1;

  using ThreadblockShape = GemmShape<64, 64, 16>;
  using WarpShape = GemmShape<32, 32, 16>;
  using InstructionShape = GemmShape<16, 8, 4>;
  static int const kStages = 3;

  using EpilogueOutputOp = epilogue::thread::LinearCombination<
      complex<double>, 1, complex<double>,
      complex<double>>;

  using Operator = arch::OpMultiplyAddComplex;
};
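Editor's usage sketch (illustrative, not from this diff): how the SM90 double-precision defaults above resolve when pulled into a device-level GEMM. The Gemm template and its parameter order below follow the usual cutlass::gemm::device::Gemm; treat the exact instantiation as an assumption.

using Sm90DoubleConfig = cutlass::gemm::device::DefaultGemmConfiguration<
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm90,
    double, double, double, double>;

static_assert(Sm90DoubleConfig::kStages == 3, "three-stage mainloop by default");

using GemmF64 = cutlass::gemm::device::Gemm<
    double, cutlass::layout::ColumnMajor,  // A
    double, cutlass::layout::ColumnMajor,  // B
    double, cutlass::layout::ColumnMajor,  // C and D
    double,                                // accumulator
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm90,
    Sm90DoubleConfig::ThreadblockShape,    // GemmShape<128, 256, 64>
    Sm90DoubleConfig::WarpShape,           // GemmShape<64, 64, 64>
    Sm90DoubleConfig::InstructionShape>;   // GemmShape<16, 8, 4>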

} // namespace device
} // namespace gemm
} // namespace cutlass

848
include/cutlass/gemm/device/ell_gemm.h
Normal file
@ -0,0 +1,848 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a Block-Ell sparse gemm kernel.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"

#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/ell_gemm.h"

#include "cutlass/gemm/kernel/default_ell_gemm.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////

/*! Blocked-Ell sparse gemm device-level operator. This is an interface to efficient CUTLASS
  Blocked-Ell kernels that may be invoked from host code.

  The contributions of this class are:

    1. At compile time, it maps data types and high-level structural parameters onto
       specific CUTLASS components.

    2. At runtime, it maps logical arguments of Blocked-Ell problems to kernel parameters.

    3. At runtime, it launches kernels on the device.

  An example of a CUTLASS EllGemm operator is as follows:

    //
    // Instantiate the CUTLASS EllGemm operator.
    //

    cutlass::gemm::device::EllGemm<
      cutlass::half_t,
      cutlass::layout::RowMajor,
      cutlass::half_t,
      cutlass::layout::ColumnMajor,
      cutlass::half_t,
      cutlass::layout::ColumnMajor,
      float,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 32>,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      cutlass::epilogue::thread::LinearCombination<
        cutlass::half_t, 128 / cutlass::sizeof_bits<cutlass::half_t>::value,
        float, float>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
      4,                                                   // Stages
      128 / cutlass::sizeof_bits<cutlass::half_t>::value,  // Alignment A
      128 / cutlass::sizeof_bits<cutlass::half_t>::value   // Alignment B
    > ellgemm_op;

    //
    // Launch the EllGemm operation on the device
    //

    Description of parameters and tensors used to represent the Blocked-Ellpack (ELL) format:
      a_rows            - Rows in the sparse matrix.
      a_cols            - Columns in the sparse matrix.
      BlockedEllA       - Packed matrix (ellValue matrix) that stores non-zero values in
                          consecutive blocks, whose size is (a_rows * a_ell_num_columns)
      ell_idx           - Blocked-ELL Column indices (ellColInd) matrix, whose size is
                          (a_rows / a_ell_blocksize) * (a_ell_num_columns / a_ell_blocksize)
      a_ell_blocksize   - Size of the ELL-Blocks.
      a_ell_num_columns - Number of columns in the Blocked-Ellpack format (ellValue columns)
      B                 - Input dense matrix whose size is (a_cols * n)
      C/D               - Output dense matrix whose size is (a_rows * n)

    cutlass::Status status = ellgemm_op({
      {a_rows, n, a_cols},  // GemmCoord problem_size
      {BlockedEllA, lda},   // TensorRef<cutlass::half_t, layout::RowMajor> ref_BlockedEllA
      {B, ldb},             // TensorRef<cutlass::half_t, layout::ColumnMajor> ref_B,
      {C, ldc},             // TensorRef<float, layout::ColumnMajor> ref_C,
      {D, ldd},             // TensorRef<float, layout::ColumnMajor> ref_D,
      ell_idx,              // Blocked-ELL Column indices or ellColInd matrix (const int*)
      a_ell_num_columns,    // Columns in the Blocked-Ellpack (ellValue) matrix (int)
      a_ell_blocksize,      // Size of the ELL-Blocks (int)
      a_ell_base,           // Base index of ellColInd (int) - Zero or One
      {alpha, beta}         // EpilogueOutputOp::Params epilogue_op_params
    });

  A simplified view of the template is listed below.

    template <
      /// Element type for A matrix operand
      typename ElementA,

      /// Layout type for A matrix operand
      typename LayoutA,

      /// Element type for B matrix operand
      typename ElementB,

      /// Layout type for B matrix operand
      typename LayoutB,

      /// Element type for C and D matrix operands
      typename ElementC,

      /// Layout type for C and D matrix operands
      typename LayoutC,

      /// Element type for internal accumulation
      typename ElementAccumulator,

      /// Operator class tag
      typename OperatorClass,

      /// Tag indicating architecture to tune for. This is the minimum SM that
      /// supports the intended feature. The device kernel can be built
      /// targeting any SM larger than this number.
      typename ArchTag,

      /// Threadblock-level tile size (concept: GemmShape)
      typename ThreadblockShape,

      /// Warp-level tile size (concept: GemmShape)
      typename WarpShape,

      /// Warp-level tile size (concept: GemmShape)
      typename InstructionShape,

      /// Epilogue output operator
      typename EpilogueOutputOp,

      /// Threadblock-level swizzling operator
      typename ThreadblockSwizzle,

      /// Number of stages used in the pipelined mainloop
      int Stages,

      /// Access granularity of A matrix in units of elements
      int AlignmentA,

      /// Access granularity of B matrix in units of elements
      int AlignmentB,

      /// Supports split-K with serial reduction
      bool SplitKSerial,

      /// Operation performed by GEMM
      typename Operator,

      /// Sparse matrix is A or not
      bool IsASparse
    >
    class EllGemm;
*/

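Editor's illustration (values invented for this note, not from the source): take a_rows = 4, a_cols = 8, a_ell_blocksize = 2, and a_ell_num_columns = 4, so each of the two block-rows stores two 2x2 blocks. If block-row 0 keeps dense block-columns {0, 2} and block-row 1 keeps {1, 3}, then:

  ellColInd (2 x 2):   [ 0  2 ]     ellValue (4 x 4): the kept 2x2 blocks of each
                       [ 1  3 ]     block-row packed side by side; B, C, and D stay dense.
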
template <
  /// Element type for A matrix operand
  typename ElementA_,
  /// Layout type for A matrix operand
  typename LayoutA_,
  /// Element type for B matrix operand
  typename ElementB_,
  /// Layout type for B matrix operand
  typename LayoutB_,
  /// Element type for C and D matrix operands
  typename ElementC_,
  /// Layout type for C and D matrix operands
  typename LayoutC_,
  /// Element type for internal accumulation
  typename ElementAccumulator_ = ElementC_,
  /// Operator class tag
  typename OperatorClass_ = arch::OpClassTensorOp,
  /// Tag indicating architecture to tune for
  typename ArchTag_ = arch::Sm80,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::ThreadblockShape,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::WarpShape,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::InstructionShape,
  /// Epilogue output operator
  typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::EpilogueOutputOp,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle_ =
      typename threadblock::GemmIdentityThreadblockSwizzle<>,
  /// Number of stages used in the pipelined mainloop
  int Stages =
      DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                               ElementC_, ElementAccumulator_>::kStages,
  /// Access granularity of A matrix in units of elements
  int AlignmentA =
      DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                               ElementC_, ElementAccumulator_>::kAlignmentA,
  /// Access granularity of B matrix in units of elements
  int AlignmentB =
      DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                               ElementC_, ElementAccumulator_>::kAlignmentB,
  /// If true, kernel supports split-K with serial reduction
  bool SplitKSerial = false,
  /// Operation performed by GEMM
  typename Operator_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::Operator,
  /// Sparse matrix is A or not
  bool IsASparse = true
>
class EllGemm {
public:

  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  using ElementB = ElementB_;
  using LayoutB = LayoutB_;
  using TensorRefB = TensorRef<ElementB const, LayoutB>;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementC, LayoutC>;
  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using EpilogueOutputOp = EpilogueOutputOp_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  using Operator = Operator_;
  static int const kStages = Stages;
  static int const kAlignmentA = AlignmentA;
  static int const kAlignmentB = AlignmentB;
  static int const kAlignmentC = EpilogueOutputOp::kCount;
  static bool const kSplitKSerial = SplitKSerial;
  static ComplexTransform const kTransformA = ComplexTransform::kNone;
  static ComplexTransform const kTransformB = ComplexTransform::kNone;
  static bool const kIsASparse = IsASparse;

  /// Define the kernel
  using GemmKernel = typename kernel::DefaultEllGemm<
    ElementA,
    LayoutA,
    kAlignmentA,
    ElementB,
    LayoutB,
    kAlignmentB,
    ElementC,
    LayoutC,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    kStages,
    kSplitKSerial,
    Operator,
    kIsASparse
  >::GemmKernel;

  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    GemmCoord problem_size;
    TensorRef<ElementA const, LayoutA> ref_A;
    TensorRef<ElementB const, LayoutB> ref_B;
    TensorRef<ElementC const, LayoutC> ref_C;
    TensorRef<ElementC, LayoutC> ref_D;
    const int* ell_idx;
    int ell_ncol;
    int ell_blocksize;
    int ell_base_idx;
    typename EpilogueOutputOp::Params epilogue;
    int split_k_slices;

    //
    // Methods
    //

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Arguments(): problem_size(0, 0, 0), split_k_slices(1) {

    }

    /// Constructs an Arguments structure
    CUTLASS_HOST_DEVICE
    Arguments(
      GemmCoord problem_size_,
      TensorRef<ElementA const, LayoutA> ref_A_,
      TensorRef<ElementB const, LayoutB> ref_B_,
      TensorRef<ElementC const, LayoutC> ref_C_,
      TensorRef<ElementC, LayoutC> ref_D_,
      const int* ell_idx_,
      int ell_ncol_,
      int ell_blocksize_,
      int ell_base_idx_,
      typename EpilogueOutputOp::Params epilogue_ =
        typename EpilogueOutputOp::Params(),
      int split_k_slices = 1
    ):
      problem_size(problem_size_),
      ref_A(ref_A_),
      ref_B(ref_B_),
      ref_C(ref_C_),
      ref_D(ref_D_),
      ell_idx(ell_idx_),
      ell_ncol(ell_ncol_),
      ell_blocksize(ell_blocksize_),
      ell_base_idx(ell_base_idx_),
      epilogue(epilogue_),
      split_k_slices(split_k_slices) {

    }
  };

private:

  /// Kernel parameters object
  typename GemmKernel::Params params_;

public:

  /// Constructs the GEMM.
  EllGemm() { }

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {

    if (!kSplitKSerial && args.split_k_slices > 1) {
      return Status::kErrorInvalidProblem;
    }

    Status status = GemmKernel::can_implement(
      args.problem_size,
      args.ref_A.non_const_ref(),
      args.ref_B.non_const_ref(),
      args.ref_C.non_const_ref(),
      args.ref_D
    );

    if (status != Status::kSuccess) {
      return status;
    }

    return Status::kSuccess;
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {

    size_t bytes = 0;

    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
      args.problem_size,
      {args.ell_blocksize,
       ThreadblockShape::kN, ThreadblockShape::kK},
      args.split_k_slices);

    tiled_shape.m() *= (args.ell_blocksize + ThreadblockShape::kM - 1) / ThreadblockShape::kM;

    if (kSplitKSerial && args.split_k_slices > 1) {

      bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
    }

    return bytes;
  }

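A worked example with illustrative numbers (editor's note): for problem_size = (1024, 512, 4096), ell_blocksize = 256, ThreadblockShape = GemmShape<128, 128, 32>, and split_k_slices = 2, the swizzle produces ceil(1024/256) = 4 block-row tiles and ceil(512/128) = 4 column tiles; scaling m by ceil(256/128) = 2 gives tiled_shape.m() = 8, so the serial split-K semaphore workspace is sizeof(int) * 8 * 4 = 128 bytes, one counter per output tile.
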
  Status set(Arguments const &args, cutlass::gemm::GemmCoord const &grid_shape, void *workspace) {
    // Initialize the Params structure
    params_ = typename GemmKernel::Params{
      args.problem_size,
      grid_shape,
      args.ref_A.non_const_ref(),
      args.ref_B.non_const_ref(),
      args.ref_C.non_const_ref(),
      args.ref_D,
      args.ell_idx,
      args.ell_ncol,
      args.ell_blocksize,
      args.ell_base_idx,
      args.epilogue,
      static_cast<int *>(workspace)
    };
    return Status::kSuccess;
  }

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
      args.problem_size,
      {args.ell_blocksize, ThreadblockShape::kN, ThreadblockShape::kK},
      args.split_k_slices);

    grid_shape.m() *= (args.ell_blocksize + ThreadblockShape::kM - 1) / ThreadblockShape::kM;

    if (kSplitKSerial) {
      if (args.split_k_slices > 1) {
        if (!workspace) {
          return Status::kErrorWorkspaceNull;
        }

        size_t bytes = get_workspace_size(args);

        cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);

        if (result != cudaSuccess) {
          return Status::kErrorInternal;
        }
      }
    }
    else {

      if (args.split_k_slices > 1) {
        return Status::kErrorInvalidProblem;
      }
    }

    return set(args, grid_shape, workspace);
  }

  /// Lightweight update given a subset of arguments
  Status update(Arguments const &args, void *workspace = nullptr) {

    if (kSplitKSerial && args.split_k_slices > 1) {
      if (!workspace) {
        return Status::kErrorWorkspaceNull;
      }
    }

    params_.ref_A.reset(args.ref_A.non_const_ref().data());
    params_.ref_B.reset(args.ref_B.non_const_ref().data());
    params_.ref_C.reset(args.ref_C.non_const_ref().data());
    params_.ref_D.reset(args.ref_D.data());
    params_.output_op = args.epilogue;
    params_.semaphore = static_cast<int *>(workspace);

    return Status::kSuccess;
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {

    ThreadblockSwizzle threadblock_swizzle;

    dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
    dim3 block(GemmKernel::kThreadCount, 1, 1);

    cudaError_t result;

    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

    if (smem_size >= (48 << 10)) {
      result = cudaFuncSetAttribute(Kernel<GemmKernel>,
                                    cudaFuncAttributeMaxDynamicSharedMemorySize,
                                    smem_size);

      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }
    }

    cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);

    result = cudaGetLastError();

    return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for column-major output exchanges problem size and operand.
template <
  /// Element type for A matrix operand
  typename ElementA_,
  /// Layout type for A matrix operand
  typename LayoutA_,
  /// Element type for B matrix operand
  typename ElementB_,
  /// Layout type for B matrix operand
  typename LayoutB_,
  /// Element type for C and D matrix operands
  typename ElementC_,
  /// Element type for internal accumulation
  typename ElementAccumulator_,
  /// Operator class tag
  typename OperatorClass_,
  /// Tag indicating architecture to tune for
  typename ArchTag_,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape_,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape_,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape_,
  /// Epilogue output operator
  typename EpilogueOutputOp_,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle_,
  /// Number of stages used in the pipelined mainloop
  int Stages,
  /// Access granularity of A matrix in units of elements
  int AlignmentA,
  /// Access granularity of B matrix in units of elements
  int AlignmentB,
  /// If true, kernel supports split-K as a serial reduction
  bool SplitKSerial,
  /// Operation performed by GEMM
  typename Operator_,
  /// Sparse matrix is A or not
  bool IsASparse>
class EllGemm<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
              layout::ColumnMajor,  // partially specialized on LayoutC
              ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_,
              WarpShape_, InstructionShape_, EpilogueOutputOp_,
              ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB,
              SplitKSerial, Operator_, IsASparse> {
public:

  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  using ElementB = ElementB_;
  using LayoutB = LayoutB_;
  using TensorRefB = TensorRef<ElementB const, LayoutB>;
  using ElementC = ElementC_;
  using LayoutC = layout::ColumnMajor;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementC, LayoutC>;
  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using EpilogueOutputOp = EpilogueOutputOp_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  using Operator = Operator_;
  static int const kStages = Stages;
  static int const kAlignmentA = AlignmentA;
  static int const kAlignmentB = AlignmentB;
  static ComplexTransform const kTransformA = ComplexTransform::kNone;
  static ComplexTransform const kTransformB = ComplexTransform::kNone;
  static bool const kSplitKSerial = SplitKSerial;
  static bool const kIsASparse = false;

  using UnderlyingOperator = EllGemm<
    ElementB,
    typename layout::LayoutTranspose<LayoutB>::type,
    ElementA,
    typename layout::LayoutTranspose<LayoutA>::type,
    ElementC,
    layout::RowMajor,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    Stages,
    kAlignmentB,
    kAlignmentA,
    SplitKSerial,
    Operator,
    kIsASparse
  >;
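Editor's note on why this exchange is sound (not in the source): a column-major matrix occupies memory exactly as its transpose does in row-major order, so producing D = alpha * A * B + beta * C into a column-major D is the same computation as producing D^T = alpha * B^T * A^T + beta * C^T into a row-major D^T. The underlying operator therefore swaps A with B, transposes both layouts, swaps the two alignments, and (in to_underlying_arguments below) swaps the m and n extents of the problem.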

  using UnderlyingArguments = typename UnderlyingOperator::Arguments;
  using GemmKernel = typename UnderlyingOperator::GemmKernel;
  static int const kAlignmentC = UnderlyingOperator::kAlignmentC;

  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    GemmCoord problem_size;
    TensorRef<ElementA const, LayoutA> ref_A;
    TensorRef<ElementB const, LayoutB> ref_B;
    TensorRef<ElementC const, LayoutC> ref_C;
    TensorRef<ElementC, LayoutC> ref_D;
    const int* ell_idx;
    int ell_ncol;
    int ell_blocksize;
    int ell_base_idx;
    typename EpilogueOutputOp::Params epilogue;
    int split_k_slices;

    //
    // Methods
    //

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Arguments() { }

    /// Constructs an Arguments structure
    CUTLASS_HOST_DEVICE
    Arguments(
      GemmCoord problem_size_,
      TensorRef<ElementA const, LayoutA> ref_A_,
      TensorRef<ElementB const, LayoutB> ref_B_,
      TensorRef<ElementC const, LayoutC> ref_C_,
      TensorRef<ElementC, LayoutC> ref_D_,
      const int* ell_idx_,
      int ell_ncol_,
      int ell_blocksize_,
      int ell_base_idx_,
      typename EpilogueOutputOp::Params epilogue_ =
        typename EpilogueOutputOp::Params(),
      int split_k_slices = 1
    ):
      problem_size(problem_size_),
      ref_A(ref_A_),
      ref_B(ref_B_),
      ref_C(ref_C_),
      ref_D(ref_D_),
      ell_idx(ell_idx_),
      ell_ncol(ell_ncol_),
      ell_blocksize(ell_blocksize_),
      ell_base_idx(ell_base_idx_),
      epilogue(epilogue_),
      split_k_slices(split_k_slices) { }
  };

private:

  UnderlyingOperator underlying_operator_;

public:

  /// Constructs the GEMM.
  EllGemm() { }

  /// Helper to construct a transposed equivalent for the underlying GEMM operator
  static UnderlyingArguments to_underlying_arguments(Arguments const &args) {
    return UnderlyingArguments(
      {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},
      {args.ref_B.data(), args.ref_B.stride(0)},
      {args.ref_A.data(), args.ref_A.stride(0)},
      {args.ref_C.data(), args.ref_C.stride(0)},
      {args.ref_D.data(), args.ref_D.stride(0)},
      args.ell_idx,
      args.ell_ncol,
      args.ell_blocksize,
      args.ell_base_idx,
      args.epilogue,
      args.split_k_slices
    );
  }

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {

    return UnderlyingOperator::can_implement(to_underlying_arguments(args));
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {

    size_t bytes = 0;

    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
      args.problem_size,
      {ThreadblockShape::kM, args.ell_blocksize, ThreadblockShape::kK},
      args.split_k_slices);

    tiled_shape.n() *= (args.ell_blocksize + ThreadblockShape::kN - 1) / ThreadblockShape::kN;

    if (kSplitKSerial && args.split_k_slices > 1) {

      bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
    }

    return bytes;
  }

  Status set(Arguments const &args, cutlass::gemm::GemmCoord const &grid_shape, void *workspace) {
    // Initialize the Params structure
    return underlying_operator_.set(to_underlying_arguments(args), grid_shape, workspace);
  }

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
      {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},
      {ThreadblockShape::kM, args.ell_blocksize, ThreadblockShape::kK},
      args.split_k_slices);

    grid_shape.n() *= (args.ell_blocksize + ThreadblockShape::kN - 1) / ThreadblockShape::kN;

    if (kSplitKSerial) {
      if (args.split_k_slices > 1) {
        if (!workspace) {
          return Status::kErrorWorkspaceNull;
        }

        size_t bytes = get_workspace_size(args);

        cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);

        if (result != cudaSuccess) {
          return Status::kErrorInternal;
        }
      }
    }
    else {

      if (args.split_k_slices > 1) {
        return Status::kErrorInvalidProblem;
      }
    }

    // Initialize the Params structure
    set(args, grid_shape, workspace);

    return Status::kSuccess;
  }

  /// Lightweight update given a subset of arguments
  Status update(Arguments const &args, void *workspace = nullptr) {

    return underlying_operator_.update(to_underlying_arguments(args), workspace);
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {

    return underlying_operator_.run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace, stream);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};

////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
@ -29,7 +29,7 @@
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
    \brief Template for a pipelined batch GEMM kernel.
*/

#pragma once

@ -58,6 +58,11 @@ namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////

/*!
  GemmUniversal is a stateful, reusable GEMM handle. Once initialized for a given GEMM computation
  (problem geometry and data references), it can be reused across different GEMM problems having the
  same geometry. (Once initialized, details regarding problem geometry and references to workspace memory
  cannot be updated.)

  The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
  batched array variants.
*/

@ -109,7 +109,6 @@ public:
  using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
  using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator;
  using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;

  using UnderlyingOperator = GemmUniversalBase<GemmKernel>;
  using Arguments = typename UnderlyingOperator::Arguments;

@ -160,10 +159,11 @@ public:
    return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
  }

  /// Lightweight update given a subset of arguments
  Status update(Arguments const &args, void *workspace = nullptr) {
  /// Lightweight update given a subset of arguments. Problem geometry is assumed to
  /// remain the same.
  Status update(Arguments const &args) {

    return underlying_operator_.update(to_underlying_arguments(args), workspace);
    return underlying_operator_.update(to_underlying_arguments(args));
  }

  /// Runs the kernel using initialized state.

@ -28,15 +28,15 @@
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*!
/*!
  \file
  \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
    batched array variants.
  \brief The universal GEMM accommodates streamk, batched strided, and batched array variants.
*/


#pragma once

//#include <limits>
#include <limits>

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
@ -44,7 +44,6 @@
#include "cutlass/device_kernel.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/gemm_universal.h"

#include "cutlass/gemm/kernel/default_gemm_universal.h"
@ -52,7 +51,7 @@

#include "cutlass/trace.h"

////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
@ -67,7 +66,7 @@ public:

  using GemmKernel = GemmKernel_;
  using ThreadblockShape = typename GemmKernel::Mma::Shape;


  using ElementA = typename GemmKernel::ElementA;
  using LayoutA = typename GemmKernel::LayoutA;
  using TensorRefA = TensorRef<ElementA const, LayoutA>;
@ -83,7 +82,8 @@ public:
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementC, LayoutC>;

  using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;
  /// Numerical accumulation element type
  using ElementAccumulator = typename GemmKernel::Mma::ElementC;

  using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
  using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
@ -94,316 +94,285 @@ public:

protected:

  /// Kernel parameters object
  typename GemmKernel::Params params_;
  //
  // Device properties (uniform across all instances of the current thread)
  //

protected:
  // Device ordinal
  thread_local static int device_ordinal_;

  /// Private helper to obtain the grid dimensions with fix-up for split-K
  static void get_grid_shape_(gemm::GemmCoord &grid_tiled_shape, int &gemm_k_size, Arguments const &args) {
  /// Device SM count
  thread_local static int device_sms_;

    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;
  /// Kernel SM occupancy (in thread blocks)
  thread_local static int sm_occupancy_;

    grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
      args.problem_size,
      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
      args.batch_count);

    gemm_k_size = args.problem_size.k();

    if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
  /// Initialize static thread-local members for the thread's current device,
  /// if necessary.
  static Status init_device_props()
  {
    CUTLASS_TRACE_HOST("GemmUniversalBase::init_device_props()");

      int const kAlignK = const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
    cudaError_t cudart_result;

      gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);

      if (gemm_k_size) {
        grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
      }
    }
  }

public:

  /// Constructs the GEMM.
  GemmUniversalBase() { }

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {

    // Determine grid shape
    cutlass::gemm::GemmCoord grid_tiled_shape;
    int gemm_k_size = 0;

    get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

    ThreadblockSwizzle threadblock_swizzle;
    dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);

    uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);

    if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) {

      return Status::kErrorInvalidProblem;
    }

    return GemmKernel::can_implement(args);
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {

    CUTLASS_TRACE_HOST("GemmUniversalBase::get_workspace_size()");

    size_t workspace_bytes = 0;

    // Determine grid shape
    cutlass::gemm::GemmCoord grid_tiled_shape;
    int gemm_k_size = 0;

    get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

    if (args.mode == GemmUniversalMode::kGemmSplitKParallel) {

      // Split-K parallel always requires a temporary workspace
      workspace_bytes =
        sizeof(ElementC) *
        size_t(args.batch_stride_D) *
        size_t(grid_tiled_shape.k());
    }
    else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) {

      // Serial split-K only requires a temporary workspace if the number of partitions along the
      // GEMM K dimension is greater than one.
      workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
    // Get current device ordinal
    int current_ordinal;
    cudart_result = cudaGetDevice(&current_ordinal);
    if (cudart_result != cudaSuccess) {
      CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(cudart_result));
      return Status::kErrorInternal;
    }

    CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
    // Done if matches the current static member
    if (current_ordinal == device_ordinal_) {
      // Already initialized
      return Status::kSuccess;
    }

    workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape);

    return workspace_bytes;
  }
    // Update SM count member
    cudart_result = cudaDeviceGetAttribute (&device_sms_, cudaDevAttrMultiProcessorCount, current_ordinal);
    if (cudart_result != cudaSuccess) {
      CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(cudart_result));
      return Status::kErrorInternal;
    }

  /// Computes the grid shape
  static dim3 get_grid_shape(Arguments const &args) {

    CUTLASS_TRACE_HOST("GemmUniversalBase::get_grid_shape()");

    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_tiled_shape;
    int gemm_k_size = 0;

    get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
    dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);

    CUTLASS_TRACE_HOST(
      " grid_tiled_shape: " << grid_tiled_shape << "\n"
      << " result = {" << result << "}");

    return result;
  }

  /// Computes the maximum number of active blocks per multiprocessor
  static int maximum_active_blocks(int smem_capacity = -1) {

    CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()");

    int max_active_blocks = -1;
    // Update the kernel function's shared memory configuration for the current device
    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
    if (smem_size >= (48 << 10))
    {
      // Requires more than 48KB: configure for extended, dynamic shared memory

    CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");

    if (smem_size <= (48 << 10)) {

      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_active_blocks,
        Kernel<GemmKernel>,
        GemmKernel::kThreadCount,
      cudart_result = cudaFuncSetAttribute(
        Kernel2<GemmKernel>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        smem_size);

      if (result == cudaSuccess) {
        CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
        return max_active_blocks;
      }
    }
    else {

      // Query assuming zero shared memory then compute occupancy limit based on SMEM
      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_active_blocks,
        Kernel<GemmKernel>,
        GemmKernel::kThreadCount,
        0);

      if (result != cudaSuccess) {

        CUTLASS_TRACE_HOST(
          " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
          << cudaGetErrorString(result));

        return -1;
      if (cudart_result != cudaSuccess) {
        CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result));
        return Status::kErrorInternal;
      }

      if (smem_capacity < 0) {
        int device_idx = 0;
        result = cudaGetDevice(&device_idx);

        if (result != cudaSuccess) {
          return -1;
        }

        cudaDeviceProp properties;
        result = cudaGetDeviceProperties(&properties, device_idx);

        if (result != cudaSuccess) {
          return -1;
        }

        smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
      }

      int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);

      CUTLASS_TRACE_HOST(" occupancy: " << occupancy);

      return occupancy;
    }

    CUTLASS_TRACE_HOST(" returning internal error");

    return -1;
  }

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

    CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace "
      << workspace << ", stream: " << (stream ? "non-null" : "null"));

    size_t workspace_bytes = get_workspace_size(args);

    CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);

    if (workspace_bytes) {

      if (!workspace) {
        CUTLASS_TRACE_HOST(" error: device workspace must not be null");

        return Status::kErrorWorkspaceNull;
      }

      if (args.mode == GemmUniversalMode::kGemm) {
        CUTLASS_TRACE_HOST(" clearing device workspace");
        cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);

        if (result != cudaSuccess) {
          CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result));

          return Status::kErrorInternal;
        }
      }
    }

    // Get CUDA grid shape
    cutlass::gemm::GemmCoord grid_tiled_shape;
    int gemm_k_size = 0;

    get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

    // Initialize the Params structure
    params_ = typename GemmKernel::Params(
      args,
      grid_tiled_shape,
      gemm_k_size,
      static_cast<int *>(workspace)
    );

    // Specify shared memory capacity for kernel.
    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

    if (smem_size >= (48 << 10)) {
      cudaError_t result = cudaFuncSetAttribute(Kernel<GemmKernel>,
                                                cudaFuncAttributeMaxDynamicSharedMemorySize,
                                                smem_size);

      if (result != cudaSuccess) {
      cudart_result = cudaFuncSetAttribute(
        Kernel2<GemmKernel>,
        cudaFuncAttributePreferredSharedMemoryCarveout, 100);  // 100% shared memory
      if (cudart_result != cudaSuccess) {
        CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result));
        return Status::kErrorInternal;
      }
    }

    return Status::kSuccess;
  }

  /// Lightweight update given a subset of arguments
  Status update(Arguments const &args, void *workspace = nullptr) {

    CUTLASS_TRACE_HOST("GemmUniversalBase()::update() - workspace: " << workspace);

    size_t workspace_bytes = get_workspace_size(args);

    if (workspace_bytes && !workspace) {
      return Status::kErrorWorkspaceNull;
    // Update SM occupancy member
    cudart_result = cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
      &sm_occupancy_,
      Kernel2<GemmKernel>,
      GemmKernel::kThreadCount,
      int(sizeof(typename GemmKernel::SharedStorage)),
      cudaOccupancyDisableCachingOverride);
    if (cudart_result != cudaSuccess) {
      CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned error " << cudaGetErrorString(cudart_result));
      return Status::kErrorInternal;
    }

    params_.update(args, workspace);


    // Update device ordinal member on success
    device_ordinal_ = current_ordinal;

    CUTLASS_TRACE_HOST("  "
      "device_ordinal: (" << device_ordinal_ << "), "
      "device_sms: (" << device_sms_ << "), "
      "sm_occupancy: (" << sm_occupancy_ << ")");

    return Status::kSuccess;
  }


protected:

  //
  // Instance data members
  //

  /// Kernel parameters
  typename GemmKernel::Params params_;


  /// Initialize params member
  Status init_params(Arguments const &args)
  {
    // Initialize static device properties, if necessary
    Status result = init_device_props();
    if (result != Status::kSuccess) {
      return result;
    }

    // Initialize params member
    params_ = typename GemmKernel::Params(args, device_sms_, sm_occupancy_);
    return Status::kSuccess;
  }

public:

  //---------------------------------------------------------------------------------------------
  // Stateless API
  //---------------------------------------------------------------------------------------------

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args)
  {
    CUTLASS_TRACE_HOST("GemmUniversalBase::can_implement()");

    // Initialize static kernel and device properties, if necessary.
    Status result = init_device_props();
    if (result != Status::kSuccess) {
      return result;
    }

    dim3 grid = get_grid_shape(args);

    if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
          grid.z <= std::numeric_limits<uint16_t>::max()))
    {
      return Status::kErrorInvalidProblem;
    }

    return GemmKernel::can_implement(args);
  }


  /// Returns the workspace size (in bytes) needed for the problem
  /// geometry expressed by these arguments
  static size_t get_workspace_size(Arguments const &args)
  {
    CUTLASS_TRACE_HOST("GemmUniversalBase::get_workspace_size()");

    // Initialize parameters from args
    GemmUniversalBase base;
    if (base.init_params(args) != Status::kSuccess) {
      return 0;
    }

    // Get size from parameters
    size_t workspace_bytes = base.params_.get_workspace_size();

    CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
    return workspace_bytes;
  }


  /// Returns the grid extents in thread blocks to launch
  static dim3 get_grid_shape(Arguments const &args)
  {
    CUTLASS_TRACE_HOST("GemmUniversalBase::get_grid_shape()");

    // Initialize parameters from args
    GemmUniversalBase base;
    if (base.init_params(args) != Status::kSuccess) {
      return dim3(0,0,0);
    }

    // Get dims from parameters
    dim3 grid_dims = base.params_.get_grid_dims();

    CUTLASS_TRACE_HOST(
      " tiled_shape: " << base.params_.get_tiled_shape() << "\n"
      << " grid_dims: {" << grid_dims << "}");

    return grid_dims;
  }


  /// Returns the maximum number of active thread blocks per multiprocessor
  static int maximum_active_blocks()
  {
    CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()");

    // Initialize static device properties, if necessary
    if (init_device_props() != Status::kSuccess) {
      return -1;
    }

    CUTLASS_TRACE_HOST(" max_active_blocks: " << sm_occupancy_);
    return sm_occupancy_;
  }


  //---------------------------------------------------------------------------------------------
  // Stateful API
  //---------------------------------------------------------------------------------------------

  /// Initializes GEMM state from arguments and workspace memory
  Status initialize(
    Arguments const &args,
    void *workspace,
    cudaStream_t stream = nullptr)
  {
    CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace "
      << workspace << ", stream: " << (stream ? "non-null" : "null"));

    // Initialize parameters from args
    Status result = init_params(args);
    if (result != Status::kSuccess) {
      return result;
    }

    // Assign and prepare workspace memory
    return params_.init_workspace(workspace, stream);
  }


  /// Lightweight update given a subset of arguments. Problem geometry is assumed to
  /// remain the same.
  Status update(Arguments const &args)
  {
    CUTLASS_TRACE_HOST("GemmUniversalBase()::update()");
    params_.update(args);
    return Status::kSuccess;
  }


  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {
  Status run(cudaStream_t stream = nullptr)
  {
    CUTLASS_TRACE_HOST("GemmUniversalBase::run()");

    //
    // Configure grid and block dimensions
    //

    ThreadblockSwizzle threadblock_swizzle;

    dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
    dim3 block(GemmKernel::kThreadCount, 1, 1);

    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
    dim3 block(GemmKernel::kThreadCount, 1, 1);
    dim3 grid = params_.get_grid_dims();

    //
    // Launch kernel
    //
    CUTLASS_TRACE_HOST("  "
      "grid: (" << grid << "), "
      "block: (" << block << "), "
      "SMEM: (" << smem_size << ")");

    CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block
      << "), SMEM: " << smem_size << " bytes");
    Kernel2<GemmKernel><<<grid, block, smem_size, stream>>>(params_);

    // Launch
    cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);

    //
    // Query for errors
    //
    cudaError_t result = cudaGetLastError();

    if (result != cudaSuccess) {
      CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
      return Status::kErrorInternal;
    }


    return Status::kSuccess;
  }


  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
  Status operator()(cudaStream_t stream = nullptr)
  {
    return run(stream);
  }


  /// Runs the kernel using initialized state.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    cudaStream_t stream = nullptr)
  {
    Status status = initialize(args, workspace, stream);


    if (status == Status::kSuccess) {
      status = run(stream);
    }
@ -412,6 +381,24 @@ public:
  }
};


/////////////////////////////////////////////////////////////////////////////////////////////////
/// Static initializers
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Device ordinal
template <typename GemmKernel_>
thread_local int GemmUniversalBase<GemmKernel_>::device_ordinal_ = -1;

/// Device SM count
template <typename GemmKernel_>
thread_local int GemmUniversalBase<GemmKernel_>::device_sms_ = -1;

/// Kernel SM occupancy (in thread blocks)
template <typename GemmKernel_>
thread_local int GemmUniversalBase<GemmKernel_>::sm_occupancy_ = -1;


/////////////////////////////////////////////////////////////////////////////////////////////////

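Editor's usage sketch (illustrative, not part of this diff): the stateful API above is typically driven as below, assuming an instantiated operator type Gemm derived from GemmUniversalBase, a populated Gemm::Arguments args, and a CUDA stream.

Gemm gemm_op;

// Reserve exactly the workspace this problem geometry requires.
size_t workspace_bytes = Gemm::get_workspace_size(args);
void *workspace = nullptr;
if (workspace_bytes) {
  cudaMalloc(&workspace, workspace_bytes);
}

cutlass::Status status = Gemm::can_implement(args);
if (status == cutlass::Status::kSuccess) {
  status = gemm_op.initialize(args, workspace, stream);  // init_params + init_workspace
}
if (status == cutlass::Status::kSuccess) {
  status = gemm_op.run(stream);  // launches Kernel2<GemmKernel>
}

if (workspace) {
  cudaFree(workspace);
}
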
} // namespace device

@ -28,8 +28,10 @@
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief
    \brief Template for a GEMM kernel that can broadcast a bias vector in the
      epilogue.
*/

#pragma once
@ -45,7 +47,7 @@
#include "cutlass/gemm/kernel/gemm_universal.h"

#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/gemm/kernel/default_gemm_with_broadcast_v2.h"
#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/device/gemm_universal_base.h"

@ -97,7 +99,7 @@ template <
    /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp'
    typename EpilogueOutputOp_ = cutlass::epilogue::thread::LinearCombinationBiasElementwise<
        ElementC_, ElementAccumulator_, ElementAccumulator_,
        ElementC_, ElementC_, 16 / sizeof(ElementC_)>,
        ElementC_, ElementC_, 128 / cutlass::sizeof_bits<ElementC_>::value>,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
    /// Number of stages used in the pipelined mainloop
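The replacement default alignment above is equivalent to the old expression for
byte-sized element types, but stays correct for sub-byte types, where sizeof()
cannot be used. A quick sanity check (illustrative, not part of the diff):

  #include "cutlass/numeric_types.h"

  static_assert(128 / cutlass::sizeof_bits<cutlass::half_t>::value == 8,
                "eight half_t elements per 128-bit access");
  static_assert(128 / cutlass::sizeof_bits<cutlass::int4b_t>::value == 32,
                "thirty-two int4b_t elements per 128-bit access");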
@ -123,7 +125,7 @@ template <
>
class GemmUniversalWithBroadcast :
  public GemmUniversalBase<
    typename kernel::DefaultGemmWithBroadcastV2<
    typename kernel::DefaultGemmWithBroadcast<
      ElementA_,
      LayoutA_,
      TransformA,
@ -166,7 +168,7 @@ class GemmUniversalWithBroadcast :
  static ComplexTransform const kTransformB = TransformB;

  using Base = GemmUniversalBase<
    typename kernel::DefaultGemmWithBroadcastV2<
    typename kernel::DefaultGemmWithBroadcast<
      ElementA_,
      LayoutA_,
      TransformA,

@ -29,7 +29,8 @@
 *
 **************************************************************************************************/
/*! \file
    \brief
    \brief Template for a GEMM kernel that can reduce one of the input matrices
      into a vector along the K dimension.
*/

#pragma once

837
include/cutlass/gemm/kernel/default_ell_gemm.h
Normal file
@ -0,0 +1,837 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * (Standard CUTLASS BSD-3-Clause redistribution terms, identical to the other
 * file headers in this commit.)
 **************************************************************************************************/

/*! \file
    \brief Default kernel-level Blocked-Ell sparse gemm operators.

      This operator combines threadblock-scoped ELL MMA
      with the appropriate threadblock-scoped epilogue.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/wmma.h"

#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/thread/linear_combination.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm.h"
#include "cutlass/gemm/kernel/gemm_pipelined.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"

#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"

#if defined(CUTLASS_ARCH_WMMA_ENABLED)
#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h"
#endif //CUTLASS_ARCH_WMMA_ENABLED

#include "cutlass/gemm/kernel/ell_gemm.h"
#include "cutlass/gemm/threadblock/default_ell_mma.h"
////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

////////////////////////////////////////////////////////////////////////////////

template <
    /// Element type for A matrix operand
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC_,
    /// Layout type for C and D matrix operands
    typename LayoutC_,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Operator class tag
    typename OperatorClass,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// If true, kernel is configured to support serial reduction in the
    /// epilogue
    bool SplitKSerial,
    /// Operation performed by GEMM
    typename Operator,
    /// Sparse matrix is A or not
    bool IsASparse>
struct DefaultEllGemm;

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Ampere Architecture
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// If true, kernel is configured to support serial reduction in the
    /// epilogue
    bool SplitKSerial,
    /// Operation performed by GEMM
    typename Operator,
    /// Sparse matrix is A or not
    bool IsASparse
>
struct DefaultEllGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
                      layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
                      arch::Sm80, ThreadblockShape, WarpShape, InstructionShape,
                      EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial,
                      Operator, IsASparse> {
  /// Define the threadblock-scoped matrix multiply-accumulate
  using Mma = typename cutlass::gemm::threadblock::DefaultEllMma<
      ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
      ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, Stages,
      Operator>::ThreadblockMma;

  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

  /// Define the epilogue
  using Epilogue =
      typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
          ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
          EpilogueOutputOp::kCount>::Epilogue;

  /// Define the kernel-level GEMM operator.
  using GemmKernel = kernel::EllGemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial, IsASparse>;
};
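A hedged sketch of composing the Sm80 specialization above into a concrete
kernel type; every tile shape, alignment, and the epilogue and swizzle choices
below are illustrative assumptions, not values fixed by this file:

  using EllKernel = typename cutlass::gemm::kernel::DefaultEllGemm<
      cutlass::half_t, cutlass::layout::RowMajor, 8,      // A
      cutlass::half_t, cutlass::layout::ColumnMajor, 8,   // B
      float, cutlass::layout::RowMajor, float,            // C/D and accumulator
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 32>,             // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 32>,               // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,                // instruction tile
      cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      /*Stages=*/3, /*SplitKSerial=*/false,
      cutlass::arch::OpMultiplyAdd, /*IsASparse=*/true>::GemmKernel;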
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Turing Architecture
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// If true, kernel is configured to support serial reduction in the epilogue
    bool SplitKSerial,
    /// Operation performed by GEMM
    typename Operator,
    /// Sparse matrix is A or not
    bool IsASparse
>
struct DefaultEllGemm<
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementC, layout::RowMajor,
    ElementAccumulator,
    arch::OpClassTensorOp,
    arch::Sm75,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    2,
    SplitKSerial,
    Operator,
    IsASparse
> {

  /// Define the threadblock-scoped matrix multiply-accumulate
  using Mma = typename cutlass::gemm::threadblock::DefaultEllMma<
      ElementA,
      LayoutA,
      kAlignmentA,
      ElementB,
      LayoutB,
      kAlignmentB,
      ElementAccumulator,
      layout::RowMajor,
      arch::OpClassTensorOp,
      arch::Sm75,
      ThreadblockShape,
      WarpShape,
      InstructionShape,
      2,
      Operator
  >::ThreadblockMma;

  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

  /// Define the epilogue
  using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
      ThreadblockShape,
      typename Mma::Operator,
      kPartitionsK,
      EpilogueOutputOp,
      EpilogueOutputOp::kCount
  >::Epilogue;

  /// Define the kernel-level GEMM operator.
  using GemmKernel = kernel::EllGemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial, IsASparse>;
};
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Number of Interleaved k
    int InterleavedK,
    /// If true, kernel is configured to support serial reduction in the
    /// epilogue
    bool SplitKSerial,
    /// Operation performed by GEMM
    typename Operator,
    /// Sparse matrix is A or not
    bool IsASparse>
struct DefaultEllGemm<
    ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
    ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB, ElementC,
    layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
    arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape,
    InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages,
    SplitKSerial, Operator, IsASparse> {
  using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
  using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
  using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;

  using ElementAccumulator = int32_t;

  /// Define the threadblock-scoped matrix multiply-accumulate
  using Mma = typename cutlass::gemm::threadblock::DefaultEllMma<
      ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
      ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80,
      ThreadblockShape, WarpShape, InstructionShape, Stages, Operator,
      true>::ThreadblockMma;

  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

  /// Define the epilogue
  using Epilogue = typename cutlass::epilogue::threadblock::
      DefaultInterleavedEpilogueTensorOp<
          ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
          64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;

  /// Define the kernel-level GEMM operator.
  using GemmKernel = kernel::EllGemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial, IsASparse>;
};
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Turing Integer Matrix Multiply Interleaved layout
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of Interleaved k
    int InterleavedK,
    /// If true, kernel is configured to support serial reduction in the
    /// epilogue
    bool SplitKSerial,
    /// Operation performed by GEMM
    typename Operator,
    /// Sparse matrix is A or not
    bool IsASparse>
struct DefaultEllGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
                      kAlignmentA, ElementB,
                      layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
                      ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
                      int32_t, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape,
                      WarpShape, InstructionShape, EpilogueOutputOp,
                      ThreadblockSwizzle, 2, SplitKSerial, Operator, IsASparse> {
  using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
  using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
  using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;

  using ElementAccumulator = int32_t;

  /// Define the threadblock-scoped matrix multiply-accumulate
  using Mma = typename cutlass::gemm::threadblock::DefaultEllMma<
      ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC,
      arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape,
      InstructionShape, 2, Operator, true>::ThreadblockMma;

  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

  /// Define the epilogue
  using Epilogue = typename cutlass::epilogue::threadblock::
      DefaultInterleavedEpilogueTensorOp<
          ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
          64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;

  /// Define the kernel-level GEMM operator.
  using GemmKernel = kernel::EllGemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial, IsASparse>;
};
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Volta architecture
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// If true, kernel is configured to support serial reduction in the epilogue
    bool SplitKSerial,
    /// Operation performed by GEMM
    typename Operator,
    /// Sparse matrix is A or not
    bool IsASparse
>
struct DefaultEllGemm<
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementC, layout::RowMajor,
    ElementAccumulator,
    arch::OpClassTensorOp,
    arch::Sm70,
    ThreadblockShape,
    WarpShape,
    GemmShape<8, 8, 4>,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    2,
    SplitKSerial,
    Operator,
    IsASparse
> {

  /// Define the threadblock-scoped matrix multiply-accumulate
  using Mma = typename cutlass::gemm::threadblock::DefaultEllMma<
      ElementA,
      LayoutA,
      kAlignmentA,
      ElementB,
      LayoutB,
      kAlignmentB,
      ElementAccumulator,
      layout::RowMajor,
      arch::OpClassTensorOp,
      arch::Sm70,
      ThreadblockShape,
      WarpShape,
      GemmShape<8, 8, 4>,
      2,
      Operator
  >::ThreadblockMma;

  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

  /// Define the epilogue
  using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp<
      ThreadblockShape,
      typename Mma::Operator,
      kPartitionsK,
      EpilogueOutputOp,
      EpilogueOutputOp::kCount
  >::Epilogue;

  /// Define the kernel-level GEMM operator.
  using GemmKernel = kernel::EllGemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial, IsASparse>;
};
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for SIMT
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// If true, kernel is configured to support serial reduction in the epilogue
    bool SplitKSerial,
    /// Operation performed by GEMM
    typename Operator,
    /// Sparse matrix is A or not
    bool IsASparse
>
struct DefaultEllGemm<
    ElementA,
    LayoutA,
    kAlignmentA,
    ElementB,
    LayoutB,
    kAlignmentB,
    ElementC,
    layout::RowMajor,
    ElementAccumulator,
    arch::OpClassSimt,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    GemmShape<1, 1, 1>,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    2,
    SplitKSerial,
    Operator,
    IsASparse> {
  /// Define the threadblock-scoped matrix multiply-accumulate
  using Mma = typename cutlass::gemm::threadblock::DefaultEllMma<
      ElementA,
      LayoutA,
      kAlignmentA,
      ElementB,
      LayoutB,
      kAlignmentB,
      ElementAccumulator,
      layout::RowMajor,
      arch::OpClassSimt,
      arch::Sm50,
      ThreadblockShape,
      WarpShape,
      GemmShape<1, 1, 1>,
      2,
      Operator>::ThreadblockMma;

  static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount;
  static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars");

  /// Define the epilogue
  using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt<
      ThreadblockShape,
      typename Mma::Operator,
      EpilogueOutputOp,
      kEpilogueElementsPerAccess
  >::Epilogue;

  /// Define the kernel-level GEMM operator.
  using GemmKernel = kernel::EllGemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial, IsASparse>;
};
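The static_assert above pins the SIMT epilogue to scalar accesses. An epilogue
output operator that satisfies it is declared with an access width of one, for
example (illustrative only):

  using SimtEpilogueOp =
      cutlass::epilogue::thread::LinearCombination<float, 1, float, float>;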
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Ampere (SIMT)
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages
    int Stages,
    /// If true, kernel is configured to support serial reduction in the epilogue
    bool SplitKSerial,
    /// Operation performed by GEMM
    typename Operator,
    /// Sparse matrix is A or not
    bool IsASparse
>
struct DefaultEllGemm<ElementA,
                      LayoutA,
                      kAlignmentA,
                      ElementB,
                      LayoutB,
                      kAlignmentB,
                      ElementC,
                      layout::RowMajor,
                      ElementAccumulator,
                      arch::OpClassSimt,
                      arch::Sm80,
                      ThreadblockShape,
                      WarpShape,
                      GemmShape<1, 1, 1>,
                      EpilogueOutputOp,
                      ThreadblockSwizzle,
                      Stages,
                      SplitKSerial,
                      Operator,
                      IsASparse> {

  /// Define the threadblock-scoped matrix multiply-accumulate
  using Mma = typename cutlass::gemm::threadblock::DefaultEllMma<
      ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
      ElementAccumulator, layout::RowMajor, arch::OpClassSimt, arch::Sm80,
      ThreadblockShape, WarpShape, GemmShape<1, 1, 1>, Stages,
      Operator>::ThreadblockMma;

  static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount;
  static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars");

  /// Define the epilogue
  using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt<
      ThreadblockShape,
      typename Mma::Operator,
      EpilogueOutputOp,
      kEpilogueElementsPerAccess
  >::Epilogue;

  /// Define the kernel-level GEMM operator.
  using GemmKernel = kernel::EllGemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial, IsASparse>;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for SIMT DP4A

template <
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Layout type for C matrix operand
    typename LayoutC,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// If true, kernel is configured to support serial reduction in the
    /// epilogue
    bool SplitKSerial,
    /// Operation performed by GEMM
    typename Operator,
    /// Sparse matrix is A or not
    bool IsASparse
>
struct DefaultEllGemm<int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB,
                      ElementC, LayoutC, ElementAccumulator, arch::OpClassSimt,
                      ArchTag, ThreadblockShape, WarpShape, GemmShape<1, 1, 4>,
                      EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial,
                      Operator, IsASparse> {
  using InstructionShape = GemmShape<1, 1, 4>;
  using ElementA = int8_t;
  using ElementB = int8_t;

  using OperatorClass = arch::OpClassSimt;

  /// Define the threadblock-scoped matrix multiply-accumulate
  using Mma = typename cutlass::gemm::threadblock::DefaultEllMma<
      ElementA, LayoutA, kAlignmentA,
      ElementB, LayoutB, kAlignmentB,
      ElementAccumulator, LayoutC,
      arch::OpClassSimt, arch::Sm50,
      ThreadblockShape, WarpShape, InstructionShape,
      2, Operator
  >::ThreadblockMma;

  static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount;
  static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars");

  /// Define the epilogue
  using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt<
      ThreadblockShape,
      typename Mma::Operator,
      EpilogueOutputOp,
      kEpilogueElementsPerAccess
  >::Epilogue;

  /// Define the kernel-level GEMM operator.
  using GemmKernel = kernel::EllGemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial, IsASparse>;
};
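The GemmShape<1, 1, 4> instruction shape above maps to the dp4a instruction,
which multiplies four int8 pairs and accumulates into an int32 in a single
step. A scalar reference of what each lane computes (a sketch, not part of the
file):

  int32_t dp4a_reference(int8_t const a[4], int8_t const b[4], int32_t c) {
    for (int i = 0; i < 4; ++i) {
      c += int32_t(a[i]) * int32_t(b[i]);  // four 8-bit MACs per instruction
    }
    return c;
  }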
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Wmma Gemm Kernel
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Layout type for C and D matrix operands
    typename LayoutC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// If true, kernel is configured to support serial reduction in the
    /// epilogue
    bool SplitKSerial,
    /// Operation performed by GEMM
    typename Operator,
    /// Sparse matrix is A or not
    bool IsASparse
>
struct DefaultEllGemm<
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementC, LayoutC,
    ElementAccumulator,
    arch::OpClassWmmaTensorOp,
    ArchTag,
    ThreadblockShape, WarpShape, InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    Stages,
    SplitKSerial,
    Operator,
    IsASparse> {
  /// Define the threadblock-scoped matrix multiply-accumulate
  using Mma = typename cutlass::gemm::threadblock::DefaultEllMma<
      ElementA, LayoutA, kAlignmentA,
      ElementB, LayoutB, kAlignmentB,
      ElementAccumulator, LayoutC,
      arch::OpClassWmmaTensorOp,
      ArchTag,
      ThreadblockShape,
      WarpShape,
      InstructionShape,
      Stages,
      Operator>::ThreadblockMma;

  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

  /// Define the epilogue
  using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp<
      ThreadblockShape,
      typename Mma::Operator,
      kPartitionsK,
      EpilogueOutputOp,
      EpilogueOutputOp::kCount
  >::Epilogue;

  /// Define the kernel-level GEMM operator.
  using GemmKernel = kernel::EllGemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial, IsASparse>;
};
////////////////////////////////////////////////////////////////////////////////
#endif //CUTLASS_ARCH_WMMA_ENABLED

////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass
@ -135,6 +135,77 @@ template <
struct DefaultGemm;

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Hopper Architecture
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// If true, kernel is configured to support serial reduction in the
    /// epilogue
    bool SplitKSerial,
    /// Operation performed by GEMM
    typename Operator,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear,
    /// Gather operand A by using an index array
    bool GatherA,
    /// Gather operand B by using an index array
    bool GatherB,
    /// Scatter result D by using an index array
    bool ScatterD,
    /// Permute result D
    typename PermuteDLayout
>
struct DefaultGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
                   layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
                   arch::Sm90, ThreadblockShape, WarpShape, InstructionShape,
                   EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial,
                   Operator, SharedMemoryClear, GatherA, GatherB, ScatterD, PermuteDLayout> {
  /// Define the threadblock-scoped matrix multiply-accumulate
  using Mma = typename cutlass::gemm::threadblock::DefaultMma<
      ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
      ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90,
      ThreadblockShape, WarpShape, InstructionShape, Stages,
      Operator, false, SharedMemoryClear, GatherA, GatherB>::ThreadblockMma;

  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

  /// Define the epilogue
  using Epilogue =
      typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
          ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
          EpilogueOutputOp::kCount, ScatterD, PermuteDLayout>::Epilogue;

  /// Define the kernel-level GEMM operator.
  using GemmKernel = kernel::Gemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
};
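In this release the Sm90 specialization reuses the Ampere-style threadblock
mainloop (DefaultMma) under the new arch tag. A hedged sketch of instantiating
it; every concrete shape and type below is an assumption for illustration:

  using Sm90Kernel = typename cutlass::gemm::kernel::DefaultGemm<
      cutlass::half_t, cutlass::layout::ColumnMajor, 8,
      cutlass::half_t, cutlass::layout::ColumnMajor, 8,
      float, cutlass::layout::RowMajor, float,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm90,
      cutlass::gemm::GemmShape<128, 128, 32>,
      cutlass::gemm::GemmShape<64, 64, 32>,
      cutlass::gemm::GemmShape<16, 8, 16>,
      cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      /*Stages=*/4, /*SplitKSerial=*/false, cutlass::arch::OpMultiplyAdd,
      cutlass::gemm::SharedMemoryClearOption::kNone,
      /*GatherA=*/false, /*GatherB=*/false, /*ScatterD=*/false,
      cutlass::layout::NoPermute>::GemmKernel;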
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Ampere Architecture

@ -119,6 +119,66 @@ struct DefaultGemmComplex;

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Hopper Architecture
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Complex elementwise transformation on A operand
    ComplexTransform TransformA,
    /// Complex elementwise transformation on B operand
    ComplexTransform TransformB,
    /// Multiply-add operator
    /// (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex)
    typename Operator,
    /// If true, kernel is configured to support serial reduction in the epilogue
    bool SplitKSerial
>
struct DefaultGemmComplex<
    ElementA, LayoutA, ElementB, LayoutB, ElementC,
    layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
    arch::Sm90, ThreadblockShape, WarpShape, InstructionShape,
    EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, Operator, SplitKSerial> {

  /// Define the threadblock-scoped matrix multiply-accumulate
  using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex<
      ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator,
      layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, ThreadblockShape,
      WarpShape, InstructionShape, Stages, TransformA, TransformB, Operator>::ThreadblockMma;

  /// Define the epilogue
  using Epilogue =
      typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp<
          ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp,
          EpilogueOutputOp::kCount, Operator>::Epilogue;

  /// Define the kernel-level GEMM operator.
  using GemmKernel = kernel::Gemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
};
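Each complex multiply-accumulate issued by this mainloop expands to four real
multiply-adds (the Gaussian variant trades one multiply for extra additions).
A scalar sketch of the arithmetic, for reference only:

  cutlass::complex<float> complex_mac(cutlass::complex<float> a,
                                      cutlass::complex<float> b,
                                      cutlass::complex<float> c) {
    return cutlass::complex<float>(
        c.real() + a.real() * b.real() - a.imag() * b.imag(),
        c.imag() + a.real() * b.imag() + a.imag() * b.real());
  }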
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Ampere Architecture
template <
    /// Element type for A matrix operand
@ -49,6 +49,7 @@
#include "cutlass/numeric_types.h"

#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/kernel/gemm_universal_streamk.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"

@ -227,12 +228,26 @@ struct DefaultGemmUniversal<
      PermuteDLayout
  >::GemmKernel;

  /// Define the kernel in terms of the default kernel
  using GemmKernel = kernel::GemmUniversal<
    typename DefaultGemmKernel::Mma,
    typename DefaultGemmKernel::Epilogue,
    ThreadblockSwizzle
  >;
  /// Universal kernel without StreamkFeature member type
  template <class SwizzleT, class Enable = void>
  class SelectBase :
    public kernel::GemmUniversal<
      typename DefaultGemmKernel::Mma,
      typename DefaultGemmKernel::Epilogue,
      SwizzleT>
  {};

  /// Universal kernel with StreamkFeature member type
  template <class SwizzleT>
  class SelectBase<SwizzleT, typename SwizzleT::StreamkFeature> :
    public kernel::GemmUniversalStreamk<
      typename DefaultGemmKernel::Mma,
      typename DefaultGemmKernel::Epilogue,
      SwizzleT>
  {};

  /// Select kernel by ThreadblockSwizzle's support for StreamkFeature
  using GemmKernel = SelectBase<ThreadblockSwizzle>;
};
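The SelectBase pair above is a member-detection idiom: the partial
specialization wins only for swizzles that expose a StreamkFeature alias
(aliased to void so that it matches the Enable = void default). A standalone
sketch of the same mechanism, with assumed type names:

  template <class SwizzleT, class Enable = void>
  struct Select { static constexpr bool kStreamK = false; };  // classic path

  template <class SwizzleT>
  struct Select<SwizzleT, typename SwizzleT::StreamkFeature> {
    static constexpr bool kStreamK = true;                    // stream-K path
  };

  struct PlainSwizzle {};
  struct StreamkSwizzle { using StreamkFeature = void; };     // opts in

  static_assert(!Select<PlainSwizzle>::kStreamK, "classic kernel selected");
  static_assert(Select<StreamkSwizzle>::kStreamK, "stream-K kernel selected");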
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -336,12 +351,26 @@ struct DefaultGemmUniversal<
      false
  >::GemmKernel;

  /// Define the kernel in terms of the default kernel
  using GemmKernel = kernel::GemmUniversal<
    typename DefaultGemmKernel::Mma,
    typename DefaultGemmKernel::Epilogue,
    ThreadblockSwizzle
  >;
  /// Universal kernel without StreamkFeature member type
  template <class SwizzleT, class Enable = void>
  class SelectBase :
    public kernel::GemmUniversal<
      typename DefaultGemmKernel::Mma,
      typename DefaultGemmKernel::Epilogue,
      SwizzleT>
  {};

  /// Universal kernel with StreamkFeature member type
  template <class SwizzleT>
  class SelectBase<SwizzleT, typename SwizzleT::StreamkFeature> :
    public kernel::GemmUniversalStreamk<
      typename DefaultGemmKernel::Mma,
      typename DefaultGemmKernel::Epilogue,
      SwizzleT>
  {};

  /// Select kernel by ThreadblockSwizzle's support for StreamkFeature
  using GemmKernel = SelectBase<ThreadblockSwizzle>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -1,242 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * (Standard CUTLASS BSD-3-Clause redistribution terms, identical to the other
 * file headers in this commit.)
 **************************************************************************************************/

/*! \file
    \brief
      Defines a GEMM with Reduction based on an existing UniversalGemm kernel.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/gemm/kernel/gemm_with_fused_epilogue_v2.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"

#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast_v2.h"
#include "cutlass/epilogue/threadblock/epilogue_with_broadcast_v2.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// Element type for A matrix operand
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Complex elementwise transformation on A operand
    ComplexTransform TransformA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Complex elementwise transformation on B operand
    ComplexTransform TransformB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC_,
    /// Layout type for C and D matrix operands
    typename LayoutC_,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Operator class tag
    typename OperatorClass,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp'
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Operation performed by GEMM
    typename Operator,
    ///
    typename Enable = void
>
struct DefaultGemmWithBroadcastV2 {

  using GemmBase = typename DefaultGemmUniversal<
    ElementA_, LayoutA_, TransformA, kAlignmentA,
    ElementB_, LayoutB_, TransformB, kAlignmentB,
    ElementC_, LayoutC_, ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    Stages,
    Operator
  >::GemmKernel;

  // Replace epilogue
  using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithBroadcastTensorOpV2<
    typename GemmBase::Epilogue::Shape,
    typename GemmBase::Epilogue::WarpMmaOperator,
    GemmBase::Epilogue::kPartitionsK,
    ElementC_,
    typename EpilogueOutputOp::ElementT,
    ElementC_,
    EpilogueOutputOp,
    GemmBase::Epilogue::kElementsPerAccess
  >::Epilogue;

  // Compose the GEMM kernel
  using GemmKernel = GemmWithFusedEpilogueV2<
    typename GemmBase::Mma,
    Epilogue,
    ThreadblockSwizzle
  >;
};
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization: ArchTag = cutlass::arch::Sm70
///
///
template <
    /// Element type for A matrix operand
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Complex elementwise transformation on A operand
    ComplexTransform TransformA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Complex elementwise transformation on B operand
    ComplexTransform TransformB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC_,
    /// Layout type for C and D matrix operands
    typename LayoutC_,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Operator class tag
    typename OperatorClass,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp'
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Operation performed by GEMM
    typename Operator,
    ///
    typename Enable
>
struct DefaultGemmWithBroadcastV2<
    ElementA_, LayoutA_, TransformA, kAlignmentA,
    ElementB_, LayoutB_, TransformB, kAlignmentB,
    ElementC_, LayoutC_,
    ElementAccumulator,
    OperatorClass,
    cutlass::arch::Sm70,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    Stages,
    Operator,
    Enable
> {

  using GemmBase = typename DefaultGemmUniversal<
    ElementA_, LayoutA_, TransformA, kAlignmentA,
    ElementB_, LayoutB_, TransformB, kAlignmentB,
    ElementC_, LayoutC_, ElementAccumulator,
    OperatorClass,
    cutlass::arch::Sm70,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    Stages,
    Operator
  >::GemmKernel;

  // Replace epilogue
  using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithBroadcastVoltaTensorOpV2<
    typename GemmBase::Epilogue::Shape,
    typename GemmBase::Epilogue::WarpMmaOperator,
    GemmBase::Epilogue::kPartitionsK,
    ElementC_,
    typename EpilogueOutputOp::ElementT,
    ElementC_,
    EpilogueOutputOp,
    GemmBase::Epilogue::kElementsPerAccess
  >::Epilogue;

  // Compose the GEMM kernel
  using GemmKernel = GemmWithFusedEpilogueV2<
    typename GemmBase::Mma,
    Epilogue,
    ThreadblockSwizzle
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -120,6 +120,84 @@ template <
    BlasMode BlasMode_ = BlasMode::kSymmetric>
struct DefaultRank2K;

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Hopper Architecture
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Fill Mode for C (kLower or kUpper)
    FillMode FillModeC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// If true, kernel is configured to support serial reduction in the
    /// epilogue
    bool SplitKSerial,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultRank2K<
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementC, layout::RowMajor, FillModeC,
    ElementAccumulator, arch::OpClassTensorOp, arch::Sm90,
    ThreadblockShape, WarpShape, InstructionShape,
    EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial,
    Operator> {
  /// Define the threadblock-scoped matrix multiply-accumulate (A x B^T)
  using Mma1 = typename cutlass::gemm::threadblock::DefaultMma<
      ElementA, LayoutA,
      kAlignmentA,
      ElementB, typename layout::LayoutTranspose<LayoutB>::type,
      kAlignmentB,
      ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90,
      ThreadblockShape, WarpShape, InstructionShape, Stages,
      Operator>::ThreadblockMma;

  /// Define the threadblock-scoped matrix multiply-accumulate (B x A^T)
  using Mma2 = typename cutlass::gemm::threadblock::DefaultMma<
      ElementB, LayoutB,
      kAlignmentB,
      ElementA, typename layout::LayoutTranspose<LayoutA>::type,
      kAlignmentA,
      ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90,
      ThreadblockShape, WarpShape, InstructionShape, Stages,
      Operator>::ThreadblockMma;

  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

  /// Define the epilogue
  using Epilogue =
      typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOpBlas3<
          ThreadblockShape, typename Mma1::Operator, kPartitionsK, EpilogueOutputOp,
          EpilogueOutputOp::kCount, BlasMode::kSymmetric>::Epilogue;

  /// Define the kernel-level Rank2K operator.
  using Rank2Kkernel = kernel::Rank2KUniversal<Mma1, Mma2, Epilogue, ThreadblockSwizzle, FillModeC, BlasMode::kSymmetric>;
};
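Mma1 and Mma2 above compute the two halves of the rank-2k update
D = alpha * (A * B^T) + alpha * (B * A^T) + beta * C, with only the FillModeC
triangle written back. A host-side reference of those semantics (illustrative
only; lower fill, column-major storage assumed):

  void syr2k_reference(int n, int k, double alpha, double beta,
                       double const *A,   // n x k
                       double const *B,   // n x k
                       double *C) {       // n x n, lower triangle updated
    for (int j = 0; j < n; ++j) {
      for (int i = j; i < n; ++i) {       // FillMode::kLower
        double acc = 0.0;
        for (int p = 0; p < k; ++p) {
          acc += A[i + p * n] * B[j + p * n] + B[i + p * n] * A[j + p * n];
        }
        C[i + j * n] = alpha * acc + beta * C[i + j * n];
      }
    }
  }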
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -163,6 +163,170 @@ template <>
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for Hopper Architecture complex datatype (symmetric)
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Fill Mode for C (kLower or kUpper)
    FillMode FillModeC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Complex elementwise transformation on A operand
    ComplexTransform TransformA,
    /// Complex elementwise transformation on B operand
    ComplexTransform TransformB,
    /// Operation performed by GEMM
    typename Operator,
    /// If true, kernel is configured to support serial reduction in the
    /// epilogue
    bool SplitKSerial>
struct DefaultRank2KComplex<
    ElementA, LayoutA, ElementB, LayoutB, ElementC,
    layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp,
    arch::Sm90, ThreadblockShape, WarpShape, InstructionShape,
    EpilogueOutputOp, ThreadblockSwizzle, Stages,
    TransformA, TransformB, Operator, SplitKSerial, BlasMode::kSymmetric> {

  static BlasMode const kBlasMode = BlasMode::kSymmetric;

  /// Define the threadblock-scoped matrix multiply-accumulate (A x B^T)
  using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex<
      ElementA, LayoutA,
      ElementB, typename layout::LayoutTranspose<LayoutB>::type,
      ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90,
      ThreadblockShape, WarpShape, InstructionShape, Stages,
      TransformA, TransformB, Operator>::ThreadblockMma;

  /// Define the threadblock-scoped matrix multiply-accumulate (B x A^T)
  using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex<
      ElementB, LayoutB,
      ElementA, typename layout::LayoutTranspose<LayoutA>::type,
      ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90,
      ThreadblockShape, WarpShape, InstructionShape, Stages,
      TransformA, TransformB, Operator>::ThreadblockMma;

  /// Define the epilogue
  using Epilogue =
      typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3<
          ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp,
          EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue;

  /// Define the kernel-level Rank2K operator.
  using Rank2Kkernel = kernel::Rank2KUniversal<Mma1, Mma2, Epilogue, ThreadblockSwizzle, FillModeC, kBlasMode>;
};

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Hopper Architecture complex datatype (hermitian)
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Fill Mode for C (kLower or kUpper)
    FillMode FillModeC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Complex elementwise transformation on A operand
    ComplexTransform TransformA,
    /// Complex elementwise transformation on B operand
    ComplexTransform TransformB,
    /// Operation performed by GEMM
    typename Operator,
    /// If true, kernel is configured to support serial reduction in the
    /// epilogue
    bool SplitKSerial>
struct DefaultRank2KComplex<
    ElementA, LayoutA, ElementB, LayoutB, ElementC,
    layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp,
    arch::Sm90, ThreadblockShape, WarpShape, InstructionShape,
    EpilogueOutputOp, ThreadblockSwizzle, Stages,
    TransformA, TransformB, Operator, SplitKSerial, BlasMode::kHermitian> {

  static BlasMode const kBlasMode = BlasMode::kHermitian;

  // Complex transform for input A and B matrices (function of input layout)
  static ComplexTransform const kTransformA = TransformA;
  static ComplexTransform const kTransformB = TransformB;

  using TransposedComplexTransform = detail::Rank2KTransposedComplexTransform<
      LayoutA, LayoutB,
      TransformA, TransformB,
      kBlasMode>;

  // Complex transform on operand A and operand B (function of blas3 computation)
  static ComplexTransform const kTransformOperandA = TransposedComplexTransform::kTransformA;
  static ComplexTransform const kTransformOperandB = TransposedComplexTransform::kTransformB;

  /// Define the threadblock-scoped matrix multiply-accumulate (A x B^H)
  using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex<
      ElementA, LayoutA,
      ElementB, typename layout::LayoutTranspose<LayoutB>::type,
      ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90,
      ThreadblockShape, WarpShape, InstructionShape, Stages,
      kTransformOperandA, kTransformOperandB, Operator>::ThreadblockMma;

  /// Define the threadblock-scoped matrix multiply-accumulate (B x A^H)
  using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex<
      ElementB, LayoutB,
      ElementA, typename layout::LayoutTranspose<LayoutA>::type,
      ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90,
      ThreadblockShape, WarpShape, InstructionShape, Stages,
      kTransformOperandA, kTransformOperandB, Operator>::ThreadblockMma;

  /// Define the epilogue
  using Epilogue =
      typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3<
          ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp,
          EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue;

  /// Define the kernel-level Rank2K operator.
  using Rank2Kkernel = kernel::Rank2KUniversal<Mma1, Mma2, Epilogue, ThreadblockSwizzle, FillModeC, kBlasMode>;
};
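// Note (editorial): for the hermitian rank-2k update C <= A * B^H + B * A^H,
// the conjugation that implements the ^H cannot be expressed by the layout
// transpose alone, so Rank2KTransposedComplexTransform folds the caller's
// TransformA/TransformB together with the conjugation required on the
// transposed operand; the resulting kTransformOperandA/B are what the two
// MMAs actually apply elementwise.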

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Ampere Architecture complex datatype (symmetric)

@ -114,6 +114,68 @@ template <
    BlasMode BlasMode_ = BlasMode::kSymmetric>
struct DefaultRankK;

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Hopper Architecture
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Fill Mode for C (kLower or kUpper)
    FillMode FillModeC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// If true, kernel is configured to support serial reduction in the
    /// epilogue
    bool SplitKSerial,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultRankK<
    ElementA, LayoutA, kAlignmentA,
    ElementC, layout::RowMajor, FillModeC,
    ElementAccumulator, arch::OpClassTensorOp, arch::Sm90,
    ThreadblockShape, WarpShape, InstructionShape,
    EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial,
    Operator> {

  /// Define the threadblock-scoped matrix multiply-accumulate (A x A^T)
  using Mma = typename cutlass::gemm::threadblock::DefaultMma<
      ElementA, LayoutA,
      kAlignmentA,
      ElementA, typename layout::LayoutTranspose<LayoutA>::type,
      kAlignmentA,
      ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90,
      ThreadblockShape, WarpShape, InstructionShape, Stages,
      Operator>::ThreadblockMma;

  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

  /// Define the epilogue
  using Epilogue =
      typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOpBlas3<
          ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
          EpilogueOutputOp::kCount, BlasMode::kSymmetric>::Epilogue;

  /// Define the kernel-level RankK operator.
  using RankKkernel = kernel::RankKUniversal<Mma, Epilogue, ThreadblockSwizzle, FillModeC>;
};
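// Note (editorial): unlike the rank-2k case above, a rank-k update
// C <= alpha * A * A^T + beta * C touches only one input matrix, so a single
// threadblock MMA with A in both operand slots (the second one transposed
// via LayoutTranspose) suffices, and RankKUniversal takes just one Mma.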
////////////////////////////////////////////////////////////////////////////////

@ -155,6 +155,140 @@ template <>
}
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Hopper Architecture complex datatype (symmetric)
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Fill Mode for C (kLower or kUpper)
    FillMode FillModeC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Complex elementwise transformation on A operand
    ComplexTransform TransformA,
    /// Operation performed by GEMM
    typename Operator,
    /// If true, kernel is configured to support serial reduction in the
    /// epilogue
    bool SplitKSerial>
struct DefaultRankKComplex<
    ElementA, LayoutA, ElementC,
    layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp,
    arch::Sm90, ThreadblockShape, WarpShape, InstructionShape,
    EpilogueOutputOp, ThreadblockSwizzle, Stages,
    TransformA, Operator, SplitKSerial, BlasMode::kSymmetric> {

  static BlasMode const kBlasMode = BlasMode::kSymmetric;

  /// Define the threadblock-scoped matrix multiply-accumulate (A x A^T)
  using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex<
      ElementA, LayoutA,
      ElementA, typename layout::LayoutTranspose<LayoutA>::type,
      ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90,
      ThreadblockShape, WarpShape, InstructionShape, Stages,
      TransformA, TransformA, Operator>::ThreadblockMma;

  /// Define the epilogue
  using Epilogue =
      typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3<
          ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp,
          EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue;

  /// Define the kernel-level RankK operator.
  using RankKkernel = kernel::RankKUniversal<Mma, Epilogue, ThreadblockSwizzle, FillModeC>;
};

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Hopper Architecture complex datatype (hermitian)
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Fill Mode for C (kLower or kUpper)
    FillMode FillModeC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Complex elementwise transformation on A operand
    ComplexTransform TransformA,
    /// Operation performed by GEMM
    typename Operator,
    /// If true, kernel is configured to support serial reduction in the
    /// epilogue
    bool SplitKSerial>
struct DefaultRankKComplex<
    ElementA, LayoutA, ElementC,
    layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp,
    arch::Sm90, ThreadblockShape, WarpShape, InstructionShape,
    EpilogueOutputOp, ThreadblockSwizzle, Stages,
    TransformA, Operator, SplitKSerial, BlasMode::kHermitian> {

  static BlasMode const kBlasMode = BlasMode::kHermitian;

  // Complex transform for input A and B matrices (function of input layout)
  static ComplexTransform const kTransformA = TransformA;

  using TransposedComplexTransform = detail::RankKTransposedComplexTransform<
      LayoutA,
      TransformA,
      kBlasMode>;

  // Complex transform on operand A and operand B (function of blas3 computation)
  static ComplexTransform const kTransformOperandA = TransposedComplexTransform::kTransformA;
  static ComplexTransform const kTransformOperandB = TransposedComplexTransform::kTransformB;

  /// Define the threadblock-scoped matrix multiply-accumulate (A x A^H)
  using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex<
      ElementA, LayoutA,
      ElementA, typename layout::LayoutTranspose<LayoutA>::type,
      ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90,
      ThreadblockShape, WarpShape, InstructionShape, Stages,
      kTransformOperandA, kTransformOperandB, Operator>::ThreadblockMma;

  /// Define the epilogue
  using Epilogue =
      typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3<
          ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp,
          EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue;

  /// Define the kernel-level RankK operator.
  using RankKkernel = kernel::RankKUniversal<Mma, Epilogue, ThreadblockSwizzle, FillModeC>;
};
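// Note (editorial): the hermitian rank-k update C <= A * A^H + beta * C needs
// the second A operand conjugated as well as transposed.
// RankKTransposedComplexTransform derives, from LayoutA and the caller's
// TransformA, which of the two operand slots must carry
// ComplexTransform::kConjugate so that the MMA effectively sees A and A^H.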

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Ampere Architecture complex datatype (symmetric)
template <
    /// Element type for A matrix operand

@ -123,6 +123,101 @@ template <
    BlasMode BlasMode_ = BlasMode::kSymmetric>
struct DefaultSymm;

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Hopper Architecture
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Side Mode for A (kLeft or kRight)
    SideMode kSideModeA,
    /// Fill Mode for A (kLower or kUpper)
    FillMode kFillModeA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// If true, kernel is configured to support serial reduction in the
    /// epilogue
    bool SplitKSerial,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultSymm<
    ElementA, LayoutA, kSideModeA, kFillModeA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementC, layout::RowMajor,
    ElementAccumulator, arch::OpClassTensorOp, arch::Sm90,
    ThreadblockShape, WarpShape, InstructionShape,
    EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial,
    Operator> {

  /// Define the threadblock-scoped triangular matrix multiply-accumulate
  /// TRMM - with diagonal: alpha * A * B or alpha * B * A
  static const DiagType kDiagTypeMma1 = DiagType::kNonUnit;
  using Mma1 = typename cutlass::gemm::threadblock::DefaultTrmm<
      ElementA, LayoutA, kAlignmentA,
      ElementB, LayoutB, kAlignmentB,
      kSideModeA, kFillModeA, kDiagTypeMma1,
      ElementAccumulator, layout::RowMajor,
      arch::OpClassTensorOp, arch::Sm90,
      ThreadblockShape, WarpShape, InstructionShape,
      Stages, Operator>::ThreadblockMma;

  /// Define the threadblock-scoped triangular matrix multiply-accumulate
  /// TRMM - withOUT diagonal: alpha * A^T * B or alpha * B * A^T
  static const DiagType kDiagTypeMma2 = DiagType::kZero;
  using LayoutAMma2 = typename platform::conditional<
      (kSideModeA == SideMode::kLeft),
      typename layout::LayoutTranspose<LayoutA>::type,
      LayoutA
  >::type;
  using LayoutBMma2 = typename platform::conditional<
      (kSideModeA == SideMode::kLeft),
      LayoutB,
      typename layout::LayoutTranspose<LayoutB>::type
  >::type;
  using Mma2 = typename cutlass::gemm::threadblock::DefaultTrmm<
      ElementA, LayoutAMma2, kAlignmentA,
      ElementB, LayoutBMma2, kAlignmentB,
      kSideModeA, InvertFillMode<kFillModeA>::mode, kDiagTypeMma2,
      ElementAccumulator, layout::RowMajor,
      arch::OpClassTensorOp, arch::Sm90,
      ThreadblockShape, WarpShape, InstructionShape,
      Stages, Operator>::ThreadblockMma;

  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

  /// Define the epilogue
  using Epilogue =
      typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
          ThreadblockShape, typename Mma1::Operator, kPartitionsK, EpilogueOutputOp,
          EpilogueOutputOp::kCount>::Epilogue;

  /// Define the kernel-level SYMM/HEMM operator.
  using SymmKernel = kernel::SymmUniversal<Mma1, Mma2, Epilogue, ThreadblockSwizzle, kSideModeA, kFillModeA>;
};
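// Note (editorial): SYMM stores the symmetric A as a single triangle, so the
// full product is reassembled from two triangular multiplies: Mma1 applies
// the stored triangle including the diagonal (DiagType::kNonUnit), and Mma2
// applies the transposed strict triangle (DiagType::kZero, with the fill mode
// inverted) so the diagonal is not counted twice. Which operand receives the
// layout transpose depends on whether A multiplies from the left or the right.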

////////////////////////////////////////////////////////////////////////////////

@ -221,7 +316,6 @@ struct DefaultSymm<
};
////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

@ -117,6 +117,199 @@ struct DefaultSymmComplex;

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Hopper Architecture complex datatype (symmetric)
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Side Mode for A (kLeft or kRight)
    SideMode kSideModeA,
    /// Fill Mode for A (kLower or kUpper)
    FillMode kFillModeA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Operation performed by GEMM
    typename Operator,
    /// If true, kernel is configured to support serial reduction in the
    /// epilogue
    bool SplitKSerial>
struct DefaultSymmComplex<
    ElementA, LayoutA, kSideModeA, kFillModeA, ElementB, LayoutB, ElementC,
    layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
    arch::Sm90, ThreadblockShape, WarpShape, InstructionShape,
    EpilogueOutputOp, ThreadblockSwizzle, Stages,
    Operator, SplitKSerial, BlasMode::kSymmetric> {

  static BlasMode const kBlasMode = BlasMode::kSymmetric;
  // Complex transforms don't apply to A or B for SYMM
  static ComplexTransform const TransformA = ComplexTransform::kNone;
  static ComplexTransform const TransformB = ComplexTransform::kNone;

  /// Define the threadblock-scoped triangular matrix multiply-accumulate
  /// TRMM - with diagonal: alpha * A * B or alpha * B * A
  static const DiagType kDiagTypeMma1 = DiagType::kNonUnit;
  using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex<
      ElementA, LayoutA,
      ElementB, LayoutB,
      kSideModeA, kFillModeA, kDiagTypeMma1,
      ElementAccumulator, layout::RowMajor,
      arch::OpClassTensorOp, arch::Sm90,
      ThreadblockShape, WarpShape, InstructionShape,
      Stages, TransformA, TransformB, Operator>::ThreadblockMma;

  /// Define the threadblock-scoped triangular matrix multiply-accumulate
  /// TRMM - withOUT diagonal: alpha * A^T * B or alpha * B * A^T
  static const DiagType kDiagTypeMma2 = DiagType::kZero;
  using LayoutAMma2 = typename platform::conditional<
      (kSideModeA == SideMode::kLeft),
      typename layout::LayoutTranspose<LayoutA>::type,
      LayoutA
  >::type;
  using LayoutBMma2 = typename platform::conditional<
      (kSideModeA == SideMode::kLeft),
      LayoutB,
      typename layout::LayoutTranspose<LayoutB>::type
  >::type;
  using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex<
      ElementA, LayoutAMma2,
      ElementB, LayoutBMma2,
      kSideModeA, InvertFillMode<kFillModeA>::mode, kDiagTypeMma2,
      ElementAccumulator, layout::RowMajor,
      arch::OpClassTensorOp, arch::Sm90,
      ThreadblockShape, WarpShape, InstructionShape,
      Stages, TransformA, TransformB, Operator>::ThreadblockMma;

  /// Define the epilogue
  using Epilogue =
      typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp<
          ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp,
          EpilogueOutputOp::kCount, Operator>::Epilogue;

  /// Define the kernel-level Symm operator.
  using SymmKernel = kernel::SymmUniversal<Mma1, Mma2, Epilogue, ThreadblockSwizzle, kSideModeA, kFillModeA>;
};
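// Note (editorial): for complex SYMM the matrix is symmetric (A == A^T), not
// hermitian, so neither triangular multiply conjugates its operands; both
// transforms stay ComplexTransform::kNone and only the layout transpose
// distinguishes Mma2 from Mma1.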

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Hopper Architecture complex datatype (hermitian)
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Side Mode for A (kLeft or kRight)
    SideMode kSideModeA,
    /// Fill Mode for A (kLower or kUpper)
    FillMode kFillModeA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Operation performed by GEMM
    typename Operator,
    /// If true, kernel is configured to support serial reduction in the
    /// epilogue
    bool SplitKSerial>
struct DefaultSymmComplex<
    ElementA, LayoutA, kSideModeA, kFillModeA, ElementB, LayoutB, ElementC,
    layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
    arch::Sm90, ThreadblockShape, WarpShape, InstructionShape,
    EpilogueOutputOp, ThreadblockSwizzle, Stages,
    Operator, SplitKSerial, BlasMode::kHermitian> {

  static BlasMode const kBlasMode = BlasMode::kHermitian;

  /// Define the threadblock-scoped triangular matrix multiply-accumulate
  /// TRMM - with diagonal: alpha * A * B or alpha * B * A
  static const DiagType kDiagTypeMma1 = DiagType::kNonUnit;
  static ComplexTransform const TransformAMma1 = ComplexTransform::kNone;
  static ComplexTransform const TransformBMma1 = ComplexTransform::kNone;
  using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex<
      ElementA, LayoutA,
      ElementB, LayoutB,
      kSideModeA, kFillModeA, kDiagTypeMma1,
      ElementAccumulator, layout::RowMajor,
      arch::OpClassTensorOp, arch::Sm90,
      ThreadblockShape, WarpShape, InstructionShape,
      Stages, TransformAMma1, TransformBMma1, Operator, BlasMode::kHermitian>::ThreadblockMma;

  /// Define the threadblock-scoped triangular matrix multiply-accumulate
  /// TRMM - withOUT diagonal - with conjugate transpose: alpha * A^H * B or alpha * B * A^H
  static const DiagType kDiagTypeMma2 = DiagType::kZero;
  using LayoutAMma2 = typename platform::conditional<
      (kSideModeA == SideMode::kLeft),
      typename layout::LayoutTranspose<LayoutA>::type,
      LayoutA
  >::type;
  using LayoutBMma2 = typename platform::conditional<
      (kSideModeA == SideMode::kLeft),
      LayoutB,
      typename layout::LayoutTranspose<LayoutB>::type
  >::type;
  static ComplexTransform const TransformAMma2 = (kSideModeA == SideMode::kLeft) ?
      ComplexTransform::kConjugate : ComplexTransform::kNone;
  static ComplexTransform const TransformBMma2 = (kSideModeA == SideMode::kLeft) ?
      ComplexTransform::kNone : ComplexTransform::kConjugate;

  using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex<
      ElementA, LayoutAMma2,
      ElementB, LayoutBMma2,
      kSideModeA, InvertFillMode<kFillModeA>::mode, kDiagTypeMma2,
      ElementAccumulator, layout::RowMajor,
      arch::OpClassTensorOp, arch::Sm90,
      ThreadblockShape, WarpShape, InstructionShape,
      Stages, TransformAMma2, TransformBMma2, Operator>::ThreadblockMma;

  /// Define the epilogue
  using Epilogue =
      typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp<
          ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp,
          EpilogueOutputOp::kCount, Operator>::Epilogue;

  /// Define the kernel-level Symm operator.
  using SymmKernel = kernel::SymmUniversal<Mma1, Mma2, Epilogue, ThreadblockSwizzle, kSideModeA, kFillModeA>;
};
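// Note (editorial): for HEMM the conjugation needed for A^H is attached to a
// different operand slot depending on the side mode (TransformAMma2 when A
// multiplies from the left, TransformBMma2 when it multiplies from the
// right); presumably DefaultMultistageTrmmComplex maps these slots onto the
// triangular operand internally, so the strict triangle is applied
// conjugated while Mma1 applies the stored triangle and diagonal
// unconjugated.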

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Ampere Architecture complex datatype (symmetric)
template <
    /// Element type for A matrix operand

@ -310,7 +503,6 @@ struct DefaultSymmComplex<

////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

@ -124,6 +124,76 @@ struct DefaultTrmm;

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Hopper Architecture
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Side Mode for the kernel
    SideMode kSideMode,
    /// Fill Mode for the triangular matrix
    FillMode kFillMode,
    /// Diag Type for the triangular matrix
    DiagType kDiagType,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// If true, kernel is configured to support serial reduction in the
    /// epilogue
    bool SplitKSerial,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultTrmm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
                   kSideMode, kFillMode, kDiagType, ElementC,
                   layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
                   arch::Sm90, ThreadblockShape, WarpShape, InstructionShape,
                   EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial,
                   Operator> {

  /// Define the threadblock-scoped triangular matrix multiply-accumulate
  using Mma = typename cutlass::gemm::threadblock::DefaultTrmm<
      ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
      kSideMode, kFillMode, kDiagType,
      ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90,
      ThreadblockShape, WarpShape, InstructionShape, Stages,
      Operator>::ThreadblockMma;

  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

  /// Define the epilogue
  using Epilogue =
      typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
          ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
          EpilogueOutputOp::kCount>::Epilogue;

  /// Define the kernel-level TRMM operator.
  using TrmmKernel = kernel::TrmmUniversal<Mma, Epilogue, ThreadblockSwizzle, kSideMode, kFillMode, kDiagType>;
};
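// Note (editorial): this kernel-level DefaultTrmm simply forwards to the
// threadblock-level cutlass::gemm::threadblock::DefaultTrmm of the same name
// (different namespace) and then bolts on the epilogue and the universal
// TRMM kernel; callers would typically reach it indirectly through a
// device-level Trmm template rather than instantiating this struct by hand.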

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Ampere Architecture
template <
    /// Element type for A matrix operand

@ -122,6 +122,74 @@ struct DefaultTrmmComplex;

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Hopper Architecture
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Side Mode for the kernel
    SideMode kSideMode,
    /// Fill Mode for the triangular matrix
    FillMode kFillMode,
    /// Diag Type for the triangular matrix
    DiagType kDiagType,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Complex elementwise transformation on A operand
    ComplexTransform TransformA,
    /// Complex elementwise transformation on B operand
    ComplexTransform TransformB,
    /// Multiply-add operator
    // (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex)
    typename Operator,
    /// If true, kernel is configured to support serial reduction in the epilogue
    bool SplitKSerial
>
struct DefaultTrmmComplex<
    ElementA, LayoutA, ElementB, LayoutB,
    kSideMode, kFillMode, kDiagType,
    ElementC, layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
    arch::Sm90, ThreadblockShape, WarpShape, InstructionShape,
    EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, Operator, SplitKSerial> {

  /// Define the threadblock-scoped matrix multiply-accumulate
  using Mma = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex<
      ElementA, LayoutA, ElementB, LayoutB,
      kSideMode, kFillMode, kDiagType,
      ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, ThreadblockShape,
      WarpShape, InstructionShape, Stages, TransformA, TransformB, Operator>::ThreadblockMma;

  /// Define the epilogue
  using Epilogue =
      typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp<
          ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp,
          EpilogueOutputOp::kCount, Operator>::Epilogue;

  /// Define the kernel-level TRMM operator.
  using TrmmKernel = kernel::TrmmUniversal<Mma, Epilogue, ThreadblockSwizzle, kSideMode, kFillMode, kDiagType>;
};
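// Note (editorial): unlike the real-valued DefaultTrmm above, the complex
// variant threads TransformA/TransformB through to the mainloop, letting the
// caller fuse an elementwise conjugation into the triangular multiply instead
// of materializing a conjugated copy of the operand.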

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for Ampere Architecture
template <
    /// Element type for A matrix operand

830
include/cutlass/gemm/kernel/ell_gemm.h
Normal file
@ -0,0 +1,830 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Template for a Block-Ell sparse gemm kernel.
*/
#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "cutlass/arch/arch.h"

#include "cutlass/transform/threadblock/ell_iterator.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////
template <
  typename Mma_,                  ///! Threadblock-scoped matrix multiply-accumulate
  typename Epilogue_,             ///! Epilogue
  typename ThreadblockSwizzle_,   ///! Threadblock swizzling function
  bool SplitKSerial,              ///! If true, code supporting split-K via serial reduction is enabled.
  bool IsASparse                  ///! If true, A is the sparse matrix
>
struct EllGemm {

  using Mma = Mma_;
  using Epilogue = Epilogue_;
  using OutputOp = typename Epilogue::OutputOp;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  static bool const kSplitKSerial = SplitKSerial;

  /// Warp count (concept: GemmShape)
  using WarpCount = typename Mma::WarpCount;
  static int const kThreadCount = 32 * WarpCount::kCount;

  /// Parameters structure
  struct Params {
    cutlass::gemm::GemmCoord problem_size;
    cutlass::gemm::GemmCoord grid_tiled_shape;
    int swizzle_log_tile;
    typename Mma::IteratorA::Params params_A;
    typename Mma::IteratorA::TensorRef ref_A;
    typename Mma::IteratorB::Params params_B;
    typename Mma::IteratorB::TensorRef ref_B;
    typename Epilogue::OutputTileIterator::Params params_C;
    typename Epilogue::OutputTileIterator::TensorRef ref_C;
    typename Epilogue::OutputTileIterator::Params params_D;
    typename Epilogue::OutputTileIterator::TensorRef ref_D;
    typename OutputOp::Params output_op;
    int *semaphore;
    int gemm_k_iterations;
    int gemm_k_size;
    const int* ell_idx;
    int ell_ncol;
    int ell_blocksize;
    int ell_base_idx;

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { }

    CUTLASS_HOST_DEVICE
    Params(
      cutlass::gemm::GemmCoord const & problem_size,
      cutlass::gemm::GemmCoord const & grid_tiled_shape,
      typename Mma::IteratorA::TensorRef ref_A,
      typename Mma::IteratorB::TensorRef ref_B,
      typename Epilogue::OutputTileIterator::TensorRef ref_C,
      typename Epilogue::OutputTileIterator::TensorRef ref_D,
      const int* ell_idx,
      int ell_ncol,
      int ell_blocksize,
      int ell_base_idx,
      typename OutputOp::Params output_op = typename OutputOp::Params(),
      int *workspace = nullptr
    ):
      problem_size(problem_size),
      grid_tiled_shape(grid_tiled_shape),
      swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
      params_A(ref_A.layout()),
      ref_A(ref_A),
      params_B(ref_B.layout()),
      ref_B(ref_B),
      params_C(ref_C.layout()),
      ref_C(ref_C),
      params_D(ref_D.layout()),
      ref_D(ref_D),
      output_op(output_op),
      ell_idx(ell_idx),
      ell_ncol(ell_ncol),
      ell_blocksize(ell_blocksize),
      ell_base_idx(ell_base_idx)
    {

      int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
      int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();

      gemm_k_size = gemm_k_iterations * Mma::Shape::kK;

      semaphore = workspace;
    }
  };
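  // Worked example (editorial): with problem_size.k() = 1000, Mma::Shape::kK = 64,
  // and grid_tiled_shape.k() = 2 split-K slices, total_gemm_k_iterations =
  // ceil(1000 / 64) = 16, gemm_k_iterations = ceil(16 / 2) = 8, and
  // gemm_k_size = 8 * 64 = 512: each slice walks at most 512 elements of K,
  // and the last slice is clipped against problem_size.k() in operator().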

  /// Shared memory storage structure
  struct SharedStorage {
    union {
      typename Mma::SharedStorage main_loop;
      typename Epilogue::SharedStorage epilogue;
    };
    typename cutlass::transform::threadblock::ell::SharedStorage ell;
  };

  //
  // Methods
  //

  CUTLASS_HOST_DEVICE
  EllGemm() { }

  /// Determines whether kernel satisfies alignment
  static Status can_implement(
    cutlass::gemm::GemmCoord const & problem_size,
    typename Mma::IteratorA::TensorRef ref_A,
    typename Mma::IteratorB::TensorRef ref_B,
    typename Epilogue::OutputTileIterator::TensorRef ref_C,
    typename Epilogue::OutputTileIterator::TensorRef ref_D) {

    static int const kAlignmentA = (platform::is_same<typename Mma::IteratorA::Layout,
                                                      layout::ColumnMajorInterleaved<32>>::value)
                                   ? 32
                                   : (platform::is_same<typename Mma::IteratorA::Layout,
                                                        layout::ColumnMajorInterleaved<64>>::value)
                                     ? 64
                                     : Mma::IteratorA::AccessType::kElements;
    static int const kAlignmentB = (platform::is_same<typename Mma::IteratorB::Layout,
                                                      layout::RowMajorInterleaved<32>>::value)
                                   ? 32
                                   : (platform::is_same<typename Mma::IteratorB::Layout,
                                                        layout::RowMajorInterleaved<64>>::value)
                                     ? 64
                                     : Mma::IteratorB::AccessType::kElements;
    static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

    if (!TensorRef_aligned(ref_A, kAlignmentA)) {
      return Status::kErrorMisalignedOperand;
    }

    if (!TensorRef_aligned(ref_B, kAlignmentB)) {
      return Status::kErrorMisalignedOperand;
    }

    if (!TensorRef_aligned(ref_C, kAlignmentC)) {
      return Status::kErrorMisalignedOperand;
    }

    if (!TensorRef_aligned(ref_D, kAlignmentC)) {
      return Status::kErrorMisalignedOperand;
    }

    if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) ||
        (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||
        (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) {

      return Status::kErrorMisalignedOperand;
    }

    return Status::kSuccess;
  }

  /// Executes one GEMM
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {

    // Compute threadblock location
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_offset =
        threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    // Early exit if CTA is out of range
    if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
        params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {

      return;
    }

    int tile_in_ell_block = (params.ell_blocksize + Mma::Shape::kM - 1) / Mma::Shape::kM;
    int ell_block_offset_m = threadblock_tile_offset.m() / tile_in_ell_block;
    int tile_offset_m = threadblock_tile_offset.m() % tile_in_ell_block;
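    // Worked example (editorial): CTA tiles are mapped onto ELL blocks by
    // ceil-division; with ell_blocksize = 256 and Mma::Shape::kM = 128,
    // tile_in_ell_block = 2, so CTA row 5 lands in ELL block row
    // 5 / 2 = 2 at tile 5 % 2 = 1 within that block.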

    // Compute position within threadblock
    int thread_idx = threadIdx.x;

    // Broadcast the warp_id computed by lane 0 to ensure dependent code
    // is compiled as warp-uniform.
    int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    int lane_idx = threadIdx.x % 32;

    typename Mma::FragmentC accumulators;

    accumulators.clear();

    // skip computation if matrix is 0
    if (params.ell_ncol > 0) {

      // Compute initial location in logical coordinates
      cutlass::MatrixCoord tb_offset_A{
        ell_block_offset_m * params.ell_blocksize
          + tile_offset_m * Mma::Shape::kM,
        threadblock_tile_offset.k() * params.gemm_k_size
      };

      cutlass::MatrixCoord tb_offset_B{
        threadblock_tile_offset.k() * params.gemm_k_size,
        threadblock_tile_offset.n() * Mma::Shape::kN
      };

      int ell_idx_start =
          (threadblock_tile_offset.m() / tile_in_ell_block) *
          (params.ell_ncol / params.ell_blocksize);
      const int* ell_idx_ptr = &(params.ell_idx[ell_idx_start]);

      // Problem size is a function of threadblock index in the K dimension
      int problem_size_k = min(
        params.problem_size.k(),
        (threadblock_tile_offset.k() + 1) * params.gemm_k_size);
      problem_size_k = min(problem_size_k, params.ell_ncol);

      // Compute threadblock-scoped matrix multiply-add
      int gemm_k_iterations =
          (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;

      // Construct iterators to A and B operands
      typename Mma::IteratorA iterator_A(
        params.params_A,
        params.ref_A.data(),
        {params.problem_size.m(), problem_size_k},
        thread_idx,
        tb_offset_A);

      typename Mma::IteratorB iterator_B(
        params.params_B,
        params.ref_B.data(),
        {problem_size_k, params.problem_size.n()},
        thread_idx,
        tb_offset_B);

      // Define coefficient for ELL index depending on LayoutB
      int ell_stride = iterator_B.get_stride();

      typename cutlass::transform::threadblock::ell::Iterator ell_iterator(
        shared_storage.ell,
        ell_idx_ptr,
        params.ell_blocksize,
        params.ell_base_idx,
        Mma::Shape::kK,
        problem_size_k,
        ell_stride,
        thread_idx
      );

      //
      // Main loop
      //

      // Construct thread-scoped matrix multiply
      Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

      if (!kSplitKSerial || gemm_k_iterations > 0) {
        // check if index computations can be skipped
        static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
        static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
        static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
        constexpr bool is_double = (sizeof(Mma::IteratorA::Element) == 8);
        constexpr bool is_multiple_alignment =
            (kAlignmentA > 1) && (kAlignmentB > 1) && (kAlignmentC > 1);
        const bool is_specialized_blocksize =
            ((params.ell_blocksize) & (params.ell_blocksize-1)) == 0
            && params.ell_blocksize >= Mma::Shape::kK;
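        // Note (editorial): ((x) & (x-1)) == 0 is the standard power-of-two
        // test (e.g. 256 & 255 == 0, but 192 & 191 != 0). Requiring the ELL
        // blocksize to be a power of two no smaller than the K-tile lets the
        // <true, true> dispatch below take a specialized mainloop, presumably
        // so that block-index arithmetic reduces to shifts and masks.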
        // Compute threadblock-scoped matrix multiply-add
        if ((is_double || is_multiple_alignment) && is_specialized_blocksize) {
          mma.operator()<true, true>(
              gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator);
        }
        else {
          mma.operator()<true, false>(
              gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator);
        }
      }
    } // if (params.ell_ncol > 0)

    //
    // Epilogue
    //

    OutputOp output_op(params.output_op);

    //
    // Masked tile iterators constructed from members
    //

    threadblock_tile_offset =
        threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    ell_block_offset_m = threadblock_tile_offset.m() / tile_in_ell_block;
    tile_offset_m = threadblock_tile_offset.m() % tile_in_ell_block;

    // assume identity swizzle
    MatrixCoord threadblock_offset(
      ell_block_offset_m * params.ell_blocksize
        + tile_offset_m * Mma::Shape::kM,
      threadblock_tile_offset.n() * Mma::Shape::kN
    );

    // avoid out of bounds
    MatrixCoord threadblock_extent(
      min(params.problem_size.m(),
          ell_block_offset_m * params.ell_blocksize
            + min((tile_offset_m + 1) * Mma::Shape::kM, params.ell_blocksize)),
      min(params.problem_size.n(),
          (threadblock_tile_offset.n() + 1) * Mma::Shape::kN)
    );

    int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

    // Construct the semaphore.
    Semaphore semaphore(params.semaphore + block_idx, thread_idx);

    // If performing a reduction via split-K, fetch the initial synchronization
    if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

      // Fetch the synchronization lock initially but do not block.
      semaphore.fetch();

      // Indicate which position in a serial reduction the output operator is currently updating
      output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
    }

    // Tile iterator loading from source tensor.
    typename Epilogue::OutputTileIterator iterator_C(
      params.params_C,
      params.ref_C.data(),
      threadblock_extent,
      thread_idx,
      threadblock_offset
    );

    // Tile iterator writing to destination tensor.
    typename Epilogue::OutputTileIterator iterator_D(
      params.params_D,
      params.ref_D.data(),
      threadblock_extent,
      thread_idx,
      threadblock_offset
    );

    Epilogue epilogue(
      shared_storage.epilogue,
      thread_idx,
      warp_idx,
      lane_idx);

    // Wait on the semaphore - this latency may have been covered by iterator construction
    if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

      // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
      if (threadblock_tile_offset.k()) {
        iterator_C = iterator_D;
      }

      semaphore.wait(threadblock_tile_offset.k());
    }

    // Execute the epilogue operator to update the destination tensor.
    epilogue(output_op, iterator_D, accumulators, iterator_C);

    //
    // Release the semaphore
    //

    if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

      int lock = 0;
      if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {

        // The final threadblock resets the semaphore for subsequent grids.
        lock = 0;
      }
      else {
        // Otherwise, the semaphore is incremented
        lock = threadblock_tile_offset.k() + 1;
      }

      semaphore.release(lock);
    }
  }
};
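// Note (editorial): the partial specialization below (IsASparse = false)
// mirrors the kernel above with the sparse operand on the B side: the
// blocked-ELL bookkeeping that was applied to the M dimension and
// Mma::Shape::kM is applied to the N dimension and Mma::Shape::kN instead.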
|
||||
// B is Sparse
|
||||
template <
|
||||
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
|
||||
>
|
||||
struct EllGemm<Mma_, Epilogue_, ThreadblockSwizzle_, SplitKSerial, false> {
|
||||
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using OutputOp = typename Epilogue::OutputOp;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
static bool const kSplitKSerial = SplitKSerial;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int swizzle_log_tile;
|
||||
typename Mma::IteratorA::Params params_A;
|
||||
typename Mma::IteratorA::TensorRef ref_A;
|
||||
typename Mma::IteratorB::Params params_B;
|
||||
typename Mma::IteratorB::TensorRef ref_B;
|
||||
typename Epilogue::OutputTileIterator::Params params_C;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C;
|
||||
typename Epilogue::OutputTileIterator::Params params_D;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D;
|
||||
typename OutputOp::Params output_op;
|
||||
int *semaphore;
|
||||
int gemm_k_iterations;
|
||||
int gemm_k_size;
|
||||
const int* ell_idx;
|
||||
int ell_ncol;
|
||||
int ell_blocksize;
|
||||
int ell_base_idx;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
cutlass::gemm::GemmCoord const & problem_size,
|
||||
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
||||
typename Mma::IteratorA::TensorRef ref_A,
|
||||
typename Mma::IteratorB::TensorRef ref_B,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D,
|
||||
const int* ell_idx,
|
||||
int ell_ncol,
|
||||
int ell_blocksize,
|
||||
int ell_base_idx,
|
||||
typename OutputOp::Params output_op = typename OutputOp::Params(),
|
||||
int *workspace = nullptr
|
||||
):
|
||||
problem_size(problem_size),
|
||||
grid_tiled_shape(grid_tiled_shape),
|
||||
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
|
||||
params_A(ref_A.layout()),
|
||||
ref_A(ref_A),
|
||||
params_B(ref_B.layout()),
|
||||
ref_B(ref_B),
|
||||
params_C(ref_C.layout()),
|
||||
ref_C(ref_C),
|
||||
params_D(ref_D.layout()),
|
||||
ref_D(ref_D),
|
||||
output_op(output_op),
|
||||
ell_idx(ell_idx),
|
||||
ell_ncol(ell_ncol),
|
||||
ell_blocksize(ell_blocksize),
|
||||
ell_base_idx(ell_base_idx)
|
||||
{
|
||||
|
||||
int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
|
||||
|
||||
gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
|
||||
|
||||
semaphore = workspace;
|
||||
}
|
||||
};
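The constructor above sizes each split-K partition by rounding the K extent up to whole Mma::Shape::kK iterations and dividing those iterations across grid_tiled_shape.k(). A worked example with hypothetical values problem_size.k() = 1024, Mma::Shape::kK = 32, and grid_tiled_shape.k() = 3:

// Worked example of the partitioning arithmetic above (values are hypothetical).
int total_gemm_k_iterations = (1024 + 32 - 1) / 32;            // 32 iterations
int gemm_k_iterations = (total_gemm_k_iterations + 3 - 1) / 3; // 11 iterations per partition
int gemm_k_size = gemm_k_iterations * 32;                      // 352 K-columns per partition
// Partitions 0 and 1 cover 352 columns each; the kernel clamps the last
// partition to the remaining 320 via min(problem_size.k(), ...).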

/// Shared memory storage structure
struct SharedStorage {
union {
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
typename cutlass::transform::threadblock::ell::SharedStorage ell;
};

//
// Methods
//

CUTLASS_HOST_DEVICE
EllGemm() { }

/// Determines whether kernel satisfies alignment
static Status can_implement(
cutlass::gemm::GemmCoord const & problem_size,
typename Mma::IteratorA::TensorRef ref_A,
typename Mma::IteratorB::TensorRef ref_B,
typename Epilogue::OutputTileIterator::TensorRef ref_C,
typename Epilogue::OutputTileIterator::TensorRef ref_D) {

static int const kAlignmentA = (platform::is_same<typename Mma::IteratorA::Layout,
layout::ColumnMajorInterleaved<32>>::value)
? 32
: (platform::is_same<typename Mma::IteratorA::Layout,
layout::ColumnMajorInterleaved<64>>::value)
? 64
: Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = (platform::is_same<typename Mma::IteratorB::Layout,
layout::RowMajorInterleaved<32>>::value)
? 32
: (platform::is_same<typename Mma::IteratorB::Layout,
layout::RowMajorInterleaved<64>>::value)
? 64
: Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

if (!TensorRef_aligned(ref_A, kAlignmentA)) {
return Status::kErrorMisalignedOperand;
}

if (!TensorRef_aligned(ref_B, kAlignmentB)) {
return Status::kErrorMisalignedOperand;
}

if (!TensorRef_aligned(ref_C, kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}

if (!TensorRef_aligned(ref_D, kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}

if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) ||
(problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||
(problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) {

return Status::kErrorMisalignedOperand;
}

return Status::kSuccess;
}
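can_implement is intended as a host-side pre-flight check, so misaligned operands are rejected before launch instead of faulting on the device. A hedged usage sketch, where EllKernel stands in for a concrete instantiation of this kernel:

// Host-side pre-flight check (sketch; EllKernel is a hypothetical instantiation).
cutlass::Status status = EllKernel::can_implement(problem_size, ref_A, ref_B, ref_C, ref_D);
if (status != cutlass::Status::kSuccess) {
  return status;   // e.g. kErrorMisalignedOperand from the checks above
}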

/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {

// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;

cutlass::gemm::GemmCoord threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {

return;
}

int tile_in_ell_block = (params.ell_blocksize + Mma::Shape::kN - 1) / Mma::Shape::kN;
int ell_block_offset_n = threadblock_tile_offset.n() / tile_in_ell_block;
int tile_offset_n = threadblock_tile_offset.n() % tile_in_ell_block;

// Compute position within threadblock
int thread_idx = threadIdx.x;

// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;

typename Mma::FragmentC accumulators;

accumulators.clear();

// Skip computation if the matrix stores zero columns
if (params.ell_ncol > 0) {

// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.k() * params.gemm_k_size,
};

cutlass::MatrixCoord tb_offset_B{
threadblock_tile_offset.k() * params.gemm_k_size,
ell_block_offset_n * params.ell_blocksize
+ tile_offset_n * Mma::Shape::kN,
};

int ell_idx_start =
(threadblock_tile_offset.n() / tile_in_ell_block) *
(params.ell_ncol / params.ell_blocksize);
const int* ell_idx_ptr = &(params.ell_idx[ell_idx_start]);

// Problem size is a function of threadblock index in the K dimension
int problem_size_k = min(
params.problem_size.k(),
(threadblock_tile_offset.k() + 1) * params.gemm_k_size);
problem_size_k = min(problem_size_k, params.ell_ncol);

// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations =
(problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;

// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
params.params_A,
params.ref_A.data(),
{params.problem_size.m(), problem_size_k},
thread_idx,
tb_offset_A);

typename Mma::IteratorB iterator_B(
params.params_B,
params.ref_B.data(),
{problem_size_k, params.problem_size.n()},
thread_idx,
tb_offset_B);

// The ELL index stride depends on LayoutA
int ell_stride = iterator_A.get_stride();

typename cutlass::transform::threadblock::ell::Iterator ell_iterator(
shared_storage.ell,
ell_idx_ptr,
params.ell_blocksize,
params.ell_base_idx,
Mma::Shape::kK,
problem_size_k,
ell_stride,
thread_idx
);

//
// Main loop
//

// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

if (!kSplitKSerial || gemm_k_iterations > 0) {
// Check whether index computations can be skipped
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
constexpr bool is_double = (sizeof(Mma::IteratorA::Element) == 8);
constexpr bool is_multiple_alignment =
(kAlignmentA > 1) && (kAlignmentB > 1) && (kAlignmentC > 1);
const bool is_specialized_blocksize =
((params.ell_blocksize) & (params.ell_blocksize-1)) == 0
&& params.ell_blocksize >= Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
if ((is_double || is_multiple_alignment) && is_specialized_blocksize) {
mma.operator()<false, true>(
gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator);
}
else {
mma.operator()<false, false>(
gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator);
}
}
} // if (params.ell_ncol > 0)

//
// Epilogue
//

OutputOp output_op(params.output_op);

//
// Masked tile iterators constructed from members
//

threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

ell_block_offset_n = threadblock_tile_offset.n() / tile_in_ell_block;
tile_offset_n = threadblock_tile_offset.n() % tile_in_ell_block;

// Assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM,
ell_block_offset_n * params.ell_blocksize
+ tile_offset_n * Mma::Shape::kN
);

// Avoid out-of-bounds accesses
MatrixCoord threadblock_extent(
min(params.problem_size.m(),
(threadblock_tile_offset.m()+1) * Mma::Shape::kM),
min(params.problem_size.n(),
ell_block_offset_n * params.ell_blocksize
+ min((tile_offset_n + 1) * Mma::Shape::kN, params.ell_blocksize))
);

int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);

// If performing a reduction via split-K, fetch the initial synchronization
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

// Fetch the synchronization lock initially but do not block.
semaphore.fetch();

// Indicate which position in a serial reduction the output operator is currently updating
output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
}

// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params.params_C,
params.ref_C.data(),
threadblock_extent,
thread_idx,
threadblock_offset
);

// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params.params_D,
params.ref_D.data(),
threadblock_extent,
thread_idx,
threadblock_offset
);

Epilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);

// Wait on the semaphore - this latency may have been covered by iterator construction
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
if (threadblock_tile_offset.k()) {
iterator_C = iterator_D;
}

semaphore.wait(threadblock_tile_offset.k());
}

// Execute the epilogue operator to update the destination tensor.
epilogue(output_op, iterator_D, accumulators, iterator_C);

//
// Release the semaphore
//

if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {

// The final threadblock resets the semaphore for subsequent grids.
lock = 0;
}
else {
// Otherwise, the semaphore is incremented
lock = threadblock_tile_offset.k() + 1;
}

semaphore.release(lock);
}
}
};
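For reference, the Blocked-ELL table consulted through params.ell_idx stores a fixed number of block-column indices per row of blocks, which is why the kernel derives its starting offset from the threadblock's block row and ell_ncol / ell_blocksize. A sketch of that addressing under the same row-major assumption (variable names here are illustrative):

// Sketch of Blocked-ELL addressing as used above (assumes a row-major table
// with blocks_per_row = ell_ncol / ell_blocksize entries per block row).
int blocks_per_row = ell_ncol / ell_blocksize;
int block_row = threadblock_tile_n / tile_in_ell_block;   // == ell_block_offset_n
const int *row_indices = &ell_idx[block_row * blocks_per_row];
// row_indices[j] names the source block-column of the j-th stored block;
// the ell::Iterator streams this list as the main loop advances in K.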

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

@ -315,13 +315,6 @@ public:
static Status can_implement(Arguments const &args) {
return Status::kSuccess;
}

static size_t get_extra_workspace_size(
Arguments const &args,
cutlass::gemm::GemmCoord const &grid_tiled_shape) {

return 0;
}

/// Executes one GEMM
CUTLASS_DEVICE

@ -50,11 +50,22 @@ namespace kernel {

namespace detail {
// Helper for correctly representing problem sizes in grouped kernels
template <bool Transposed>
template <
typename ThreadblockShape,
bool Transposed
>
struct GemmGroupedProblemSizeHelper {

static bool const kTransposed = Transposed;

CUTLASS_HOST_DEVICE
static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) {
return cutlass::gemm::GemmCoord(
((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM),
((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN),
1);
}
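grid_shape above is a ceiling division of each problem extent by the threadblock tile. A worked example for a hypothetical 129x256x64 problem with a 128x128x32 tile:

// Worked example for grid_shape (problem and tile shapes are hypothetical).
using Helper = detail::GemmGroupedProblemSizeHelper<
    cutlass::gemm::GemmShape<128, 128, 32>, false>;
cutlass::gemm::GemmCoord grid = Helper::grid_shape({129, 256, 64});
// m: (129 - 1 + 128) / 128 = 2 tiles; n: (256 - 1 + 128) / 128 = 2 tiles
// grid == {2, 2, 1}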

CUTLASS_HOST_DEVICE
static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) {
if (kTransposed) {
@ -77,7 +88,7 @@ template <typename ThreadblockShape,
int ThreadCount,
bool Transposed = false>
struct GemmGroupedProblemVisitor : public GroupedProblemVisitor<
detail::GemmGroupedProblemSizeHelper<Transposed>,
detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>,
ThreadblockShape,
GroupScheduleMode_,
PrefetchTileCount,
@ -85,7 +96,7 @@ struct GemmGroupedProblemVisitor : public GroupedProblemVisitor<

static bool const kTransposed = Transposed;

using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<Transposed>;
using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>;
using Base = GroupedProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
using Params = typename Base::Params;
using SharedStorage = typename Base::SharedStorage;

@ -335,13 +335,6 @@ public:
return Status::kSuccess;
}

static size_t get_extra_workspace_size(
Arguments const &args,
cutlass::gemm::GemmCoord const &grid_tiled_shape) {

return 0;
}

/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {

@ -41,6 +41,7 @@
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/gemm/kernel/params_universal_base.h"

#include "cutlass/layout/matrix.h"

@ -104,16 +105,12 @@ public:
//

/// Argument structure
struct Arguments {

struct Arguments : UniversalArgumentsBase
{
//
// Data members
//

GemmUniversalMode mode;
GemmCoord problem_size;
int batch_count;

typename EpilogueOutputOp::Params epilogue;

void const * ptr_A;
@ -132,7 +129,6 @@ public:
int64_t batch_stride_gamma;
int64_t batch_stride_beta;
int64_t batch_stride_C;
int64_t batch_stride_D;

typename LayoutA::Stride stride_a;
typename LayoutB::Stride stride_b;
@ -161,14 +157,13 @@ public:
//

Arguments():
mode(GemmUniversalMode::kGemm),
batch_count(1),
ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr),
ptr_var(nullptr), ptr_mean(nullptr),
ptr_gamma(nullptr), ptr_beta(nullptr),
ptr_gather_A_indices(nullptr),
ptr_gather_B_indices(nullptr),
ptr_scatter_D_indices(nullptr) {}
ptr_scatter_D_indices(nullptr)
{}

/// constructs an arguments structure
Arguments(
@ -202,31 +197,27 @@ public:
typename LayoutC::Stride stride_d,
int const *ptr_gather_A_indices = nullptr,
int const *ptr_gather_B_indices = nullptr,
int const *ptr_scatter_D_indices = nullptr
):
mode(mode),
problem_size(problem_size),
batch_count(batch_count),
int const *ptr_scatter_D_indices = nullptr)
:
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
epilogue(epilogue),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
ptr_var(ptr_var), ptr_mean(ptr_mean),
ptr_gamma(ptr_gamma), ptr_beta(ptr_beta),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),
batch_stride_var(batch_stride_var), batch_stride_mean(batch_stride_mean),
batch_stride_gamma(batch_stride_gamma), batch_stride_beta(batch_stride_beta),
lda(0), ldb(0), ldc(0), ldd(0),
ld_var(0), ld_mean(0),
ld_gamma(0), ld_beta(0),
stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d),
stride_var(stride_var), stride_mean(stride_mean),
stride_gamma(stride_gamma), stride_beta(stride_beta),
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
ptr_scatter_D_indices(ptr_scatter_D_indices) {
lda = 0;
ldb = 0;
ldc = 0;
ldd = 0;
ld_var = 0;
ld_mean = 0;
ld_gamma = 0;
ld_beta = 0;
ptr_scatter_D_indices(ptr_scatter_D_indices)
{
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
}
}

/// constructs an arguments structure
Arguments(
@ -260,23 +251,22 @@ public:
typename LayoutC::Stride::LongIndex ldd,
int const *ptr_gather_A_indices = nullptr,
int const *ptr_gather_B_indices = nullptr,
int const *ptr_scatter_D_indices = nullptr
):
mode(mode),
problem_size(problem_size),
batch_count(batch_count),
int const *ptr_scatter_D_indices = nullptr)
:
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
epilogue(epilogue),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
ptr_var(ptr_var), ptr_mean(ptr_mean),
ptr_gamma(ptr_gamma), ptr_beta(ptr_beta),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),
batch_stride_var(batch_stride_var), batch_stride_mean(batch_stride_mean),
batch_stride_gamma(batch_stride_gamma), batch_stride_beta(batch_stride_beta),
lda(lda), ldb(ldb), ldc(ldc), ldd(ldd),
ld_var(ld_var), ld_mean(ld_mean),
ld_gamma(ld_gamma), ld_beta(ld_beta),
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
ptr_scatter_D_indices(ptr_scatter_D_indices) {
ptr_scatter_D_indices(ptr_scatter_D_indices)
{
stride_a = make_Coord(lda);
stride_b = make_Coord(ldb);
stride_c = make_Coord(ldc);
@ -286,7 +276,7 @@ public:
stride_gamma = make_Coord(ld_gamma);
stride_beta = make_Coord(ld_beta);
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
}
}

/// Returns arguments for the transposed problem
Arguments transposed_problem() const {
@ -303,17 +293,30 @@ public:
}
};


//
// Structure for precomputing values in host memory and passing to kernels
//

/// Parameters structure
struct Params {
struct Params : UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC>
{
using ParamsBase = UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC>;

//
// Data members
//

cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;

typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
typename Epilogue::OutputTileIterator::Params params_C;
@ -321,10 +324,6 @@ public:

typename EpilogueOutputOp::Params output_op;

GemmUniversalMode mode;
int batch_count;
int gemm_k_size;

void * ptr_A;
void * ptr_B;
void * ptr_var;
@ -341,65 +340,30 @@ public:
int64_t batch_stride_gamma;
int64_t batch_stride_beta;
int64_t batch_stride_C;
int64_t batch_stride_D;

int * ptr_gather_A_indices;
int * ptr_gather_B_indices;
int * ptr_scatter_D_indices;

int *semaphore;

//
// Methods
// Host dispatch API
//

CUTLASS_HOST_DEVICE
Params():
swizzle_log_tile(0),
params_A(0),
params_B(0),
params_C(0),
params_D(0),
batch_count(0),
gemm_k_size(0),
mode(cutlass::gemm::GemmUniversalMode::kGemm),
ptr_A(nullptr),
ptr_B(nullptr),
ptr_var(nullptr),
ptr_mean(nullptr),
ptr_gamma(nullptr),
ptr_beta(nullptr),
ptr_C(nullptr),
ptr_D(nullptr),
batch_stride_A(0),
batch_stride_B(0),
batch_stride_var(0),
batch_stride_mean(0),
batch_stride_C(0),
batch_stride_D(0),
ptr_gather_A_indices(nullptr),
ptr_gather_B_indices(nullptr),
ptr_scatter_D_indices(nullptr),
semaphore(nullptr) { }
/// Default constructor
Params() = default;

CUTLASS_HOST_DEVICE
/// Constructor
Params(
Arguments const &args,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
int gemm_k_size,
void *workspace = nullptr
):
problem_size(args.problem_size),
grid_tiled_shape(grid_tiled_shape),
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
Arguments const &args, /// GEMM application arguments
int device_sms, /// Number of SMs on the device
int sm_occupancy) /// Kernel SM occupancy (in thread blocks)
:
ParamsBase(args, device_sms, sm_occupancy),
params_A(args.lda ? make_Coord_with_padding<LayoutA::kStrideRank>(args.lda) : args.stride_a),
params_B(args.ldb ? make_Coord_with_padding<LayoutB::kStrideRank>(args.ldb) : args.stride_b),
params_C(args.ldc ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldc) : args.stride_c),
params_D(args.ldd ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldd) : args.stride_d),
output_op(args.epilogue),
mode(args.mode),
batch_count(args.batch_count),
gemm_k_size(gemm_k_size),
ptr_A(const_cast<void *>(args.ptr_A)),
ptr_B(const_cast<void *>(args.ptr_B)),
ptr_var(const_cast<void *>(args.ptr_var)),
@ -415,19 +379,15 @@ public:
batch_stride_gamma(args.batch_stride_gamma),
batch_stride_beta(args.batch_stride_beta),
batch_stride_C(args.batch_stride_C),
batch_stride_D(args.batch_stride_D),
ptr_gather_A_indices(const_cast<int *>(args.ptr_gather_A_indices)),
ptr_gather_B_indices(const_cast<int *>(args.ptr_gather_B_indices)),
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices)),
semaphore(static_cast<int *>(workspace)) {

}

CUTLASS_HOST_DEVICE
void update(
Arguments const &args,
void *workspace = nullptr) {
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices))
{}

/// Lightweight update given a subset of arguments. Problem geometry is assumed
/// to remain the same.
void update(Arguments const &args)
{
ptr_A = const_cast<void *>(args.ptr_A);
ptr_B = const_cast<void *>(args.ptr_B);
ptr_var = const_cast<void *>(args.ptr_var);
@ -441,22 +401,13 @@ public:
ptr_gather_B_indices = const_cast<int *>(args.ptr_gather_B_indices);
ptr_scatter_D_indices = const_cast<int *>(args.ptr_scatter_D_indices);

batch_stride_A = args.batch_stride_A;
batch_stride_B = args.batch_stride_B;
batch_stride_var = args.batch_stride_var;
batch_stride_mean = args.batch_stride_mean;
batch_stride_gamma = args.batch_stride_gamma;
batch_stride_beta = args.batch_stride_beta;
batch_stride_C = args.batch_stride_C;
batch_stride_D = args.batch_stride_D;

output_op = args.epilogue;

semaphore = static_cast<int *>(workspace);
CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");
}
};
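The Params refactoring above splits construction (heavyweight, host-only) from update (a lightweight pointer swap). A sketch of the intended call pattern, with GemmKernel, device_sms, and sm_occupancy as placeholder names:

// Sketch: reuse one precomputed Params across launches when only tensor
// pointers change; problem geometry must stay fixed (see update() above).
typename GemmKernel::Params params(args, device_sms, sm_occupancy);
for (int step = 0; step < num_steps; ++step) {
  args.ptr_A = ptrs_A[step];    // swap operand pointers only
  args.ptr_B = ptrs_B[step];
  params.update(args);          // no grid-shape or stride recomputation
  // ... launch the kernel with params ...
}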


/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
@ -466,12 +417,9 @@ public:
public:

//
// Methods
// Host dispatch API
//

CUTLASS_DEVICE
GemmLayernormMainloopFusion() { }

/// Determines whether kernel satisfies alignment
static Status can_implement(
cutlass::gemm::GemmCoord const & problem_size) {
@ -555,12 +503,23 @@ public:
return can_implement(args.problem_size);
}

static size_t get_extra_workspace_size(Arguments const &args,
cutlass::gemm::GemmCoord const &grid_tiled_shape) {
public:

return 0;
//
// Device-only API
//

// Factory invocation
CUTLASS_DEVICE
static void invoke(
Params const &params,
SharedStorage &shared_storage)
{
GemmLayernormMainloopFusion op;
op(params, shared_storage);
}
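The static invoke above exists so a single templated __global__ entry point can instantiate the kernel functor on the device. A sketch of such a wrapper (modeled on CUTLASS's device kernel launcher; names here are illustrative):

// Illustrative __global__ wrapper enabled by the device-side factory above.
template <typename Operator>
__global__ void kernel_entry(typename Operator::Params params) {
  extern __shared__ int smem[];
  typename Operator::SharedStorage *shared_storage =
      reinterpret_cast<typename Operator::SharedStorage *>(smem);
  Operator::invoke(params, *shared_storage);  // constructs the op and runs operator()
}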


/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {

@ -41,6 +41,7 @@
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/gemm/kernel/params_universal_base.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -105,16 +106,12 @@ public:
//

/// Argument structure
struct Arguments {

struct Arguments : UniversalArgumentsBase
{
//
// Data members
//

GemmUniversalMode mode;
GemmCoord problem_size;
int batch_count;

typename EpilogueOutputOp::Params epilogue;

void const * ptr_A_real;
@ -144,17 +141,13 @@ public:
int64_t batch_stride_B_imag;
int64_t batch_stride_C;
int64_t batch_stride_C_imag;
int64_t batch_stride_D;
int64_t batch_stride_D_imag;


//
// Methods
//

Arguments():
mode(GemmUniversalMode::kGemm),
batch_count(1),
Arguments() :
ptr_A_real(nullptr),
ptr_A_imag(nullptr),
ptr_B_real(nullptr),
@ -163,7 +156,7 @@ public:
ptr_C_imag(nullptr),
ptr_D_real(nullptr),
ptr_D_imag(nullptr)
{ }
{}

/// constructs an arguments structure
Arguments(
@ -194,11 +187,9 @@ public:
int64_t batch_stride_C = 0,
int64_t batch_stride_C_imag = 0,
int64_t batch_stride_D = 0,
int64_t batch_stride_D_imag = 0
):
mode(mode),
problem_size(problem_size),
batch_count(batch_count),
int64_t batch_stride_D_imag = 0)
:
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
epilogue(epilogue),
ptr_A_real(ptr_A_real),
ptr_A_imag(ptr_A_imag),
@ -222,10 +213,8 @@ public:
batch_stride_B_imag(batch_stride_B_imag),
batch_stride_C(batch_stride_C),
batch_stride_C_imag(batch_stride_C_imag),
batch_stride_D(batch_stride_D),
batch_stride_D_imag(batch_stride_D_imag) {

}
batch_stride_D_imag(batch_stride_D_imag)
{}

/// Returns arguments for the transposed problem
Arguments transposed_problem() const {
@ -243,16 +232,30 @@ public:
}
};


//
// Structure for precomputing values in host memory and passing to kernels
//

/// Parameters structure
struct Params {
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;

struct Params : UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC>
{
using ParamsBase = UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC>;

//
// Data members
//

typename Mma::IteratorA::Params params_A_real;
typename Mma::IteratorA::Params params_A_imag;
typename Mma::IteratorB::Params params_B_real;
@ -264,10 +267,6 @@ public:

typename EpilogueOutputOp::Params output_op;

GemmUniversalMode mode;
int batch_count;
int gemm_k_size;

void * ptr_A_real;
void * ptr_A_imag;
void * ptr_B_real;
@ -278,54 +277,28 @@ public:
void * ptr_D_imag;

int64_t batch_stride_A;
int64_t batch_stride_A_imag;
int64_t batch_stride_B;
int64_t batch_stride_B_imag;
int64_t batch_stride_C;

int64_t batch_stride_A_imag;
int64_t batch_stride_B_imag;
int64_t batch_stride_C_imag;
int64_t batch_stride_D;
int64_t batch_stride_D_imag;

int *semaphore;

//
// Methods
// Host dispatch API
//

CUTLASS_HOST_DEVICE
Params():
batch_count(0),
gemm_k_size(0),
swizzle_log_tile(0),
mode(cutlass::gemm::GemmUniversalMode::kGemm),
ptr_A_real(nullptr),
ptr_A_imag(nullptr),
ptr_B_real(nullptr),
ptr_B_imag(nullptr),
ptr_C_real(nullptr),
ptr_C_imag(nullptr),
ptr_D_real(nullptr),
ptr_D_imag(nullptr),
batch_stride_A(0),
batch_stride_A_imag(0),
batch_stride_B(0),
batch_stride_B_imag(0),
batch_stride_C(0),
batch_stride_C_imag(0),
batch_stride_D(0),
batch_stride_D_imag(0),
semaphore(nullptr) { }
/// Default constructor
Params() = default;

CUTLASS_HOST_DEVICE
/// Constructor
Params(
Arguments const &args,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
int gemm_k_size,
void *workspace = nullptr
):
problem_size(args.problem_size),
grid_tiled_shape(grid_tiled_shape),
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
Arguments const &args, /// GEMM application arguments
int device_sms, /// Number of SMs on the device
int sm_occupancy) /// Kernel SM occupancy (in thread blocks)
:
ParamsBase(args, device_sms, sm_occupancy),
params_A_real(args.lda_real),
params_A_imag(args.lda_imag),
params_B_real(args.ldb_real),
@ -335,9 +308,6 @@ public:
params_D_real(args.ldd_real),
params_D_imag(args.ldd_imag),
output_op(args.epilogue),
mode(args.mode),
batch_count(args.batch_count),
gemm_k_size(gemm_k_size),
ptr_A_real(const_cast<void *>(args.ptr_A_real)),
ptr_A_imag(const_cast<void *>(args.ptr_A_imag)),
ptr_B_real(const_cast<void *>(args.ptr_B_real)),
@ -347,21 +317,32 @@ public:
ptr_D_real(args.ptr_D_real),
ptr_D_imag(args.ptr_D_imag),
batch_stride_A(args.batch_stride_A),
batch_stride_A_imag(args.batch_stride_A_imag),
batch_stride_B(args.batch_stride_B),
batch_stride_B_imag(args.batch_stride_B_imag),
batch_stride_C(args.batch_stride_C),
batch_stride_A_imag(args.batch_stride_A_imag),
batch_stride_B_imag(args.batch_stride_B_imag),
batch_stride_C_imag(args.batch_stride_C_imag),
batch_stride_D(args.batch_stride_D),
batch_stride_D_imag(args.batch_stride_D_imag),
semaphore(static_cast<int *>(workspace)) {
batch_stride_D_imag(args.batch_stride_D_imag)
{}

/// Returns the workspace size (in bytes) needed for this problem geometry
size_t get_workspace_size() const
{
size_t workspace_bytes = ParamsBase::get_workspace_size();
if (this->mode == GemmUniversalMode::kGemmSplitKParallel)
{
// Double the size returned by the base class because we need to
// accumulate two ElementC components
workspace_bytes *= 2;
}

return workspace_bytes;
}

void update(
Arguments const &args,
void *workspace = nullptr) {

/// Lightweight update given a subset of arguments. Problem geometry is assumed
/// to remain the same.
void update(Arguments const &args)
{
ptr_A_real = const_cast<void *>(args.ptr_A_real);
ptr_A_imag = const_cast<void *>(args.ptr_A_imag);

@ -374,21 +355,11 @@ public:
ptr_D_real = const_cast<void *>(args.ptr_D_real);
ptr_D_imag = const_cast<void *>(args.ptr_D_imag);

batch_stride_A = args.batch_stride_A;
batch_stride_A_imag = args.batch_stride_A_imag;
batch_stride_B = args.batch_stride_B;
batch_stride_B_imag = args.batch_stride_B_imag;
batch_stride_C = args.batch_stride_C;
batch_stride_C_imag = args.batch_stride_C_imag;
batch_stride_D = args.batch_stride_D;
batch_stride_D_imag = args.batch_stride_D_imag;

output_op = args.epilogue;

semaphore = static_cast<int *>(workspace);
}
};
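The get_workspace_size override above doubles the base allocation because a planar-complex accumulator carries disjoint real and imaginary planes, and split-K parallel mode must stage both per partition. A worked example with hypothetical numbers, assuming the base class sizes one ElementC plane per partition:

// Worked example (hypothetical sizes; assumes the base class allocates one
// accumulator plane per split-K partition).
size_t plane_bytes = size_t(1024) * 1024 * sizeof(float); // one 1024x1024 plane
size_t base_bytes = 4 * plane_bytes;                      // 4 partitions
size_t workspace_bytes = 2 * base_bytes;                  // real + imaginary planes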


/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
@ -398,15 +369,12 @@ public:
public:

//
// Methods
// Host dispatch API
//

CUTLASS_DEVICE
GemmPlanarComplex() { }

/// Determines whether kernel satisfies alignment
static Status can_implement(Arguments const &args) {

static Status can_implement(Arguments const &args)
{
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
@ -440,12 +408,23 @@ public:
return Status::kSuccess;
}

static size_t get_extra_workspace_size(Arguments const &args,
cutlass::gemm::GemmCoord const &grid_tiled_shape) {
public:

return 0;
//
// Device-only API
//

// Factory invocation
CUTLASS_DEVICE
static void invoke(
Params const &params,
SharedStorage &shared_storage)
{
GemmPlanarComplex op;
op(params, shared_storage);
}

/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {

@ -41,6 +41,7 @@
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/gemm/kernel/params_universal_base.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -105,16 +106,12 @@ public:
//

/// Argument structure
struct Arguments {

struct Arguments : UniversalArgumentsBase
{
//
// Data members
//

GemmUniversalMode mode;
GemmCoord problem_size;
int batch_count;

typename EpilogueOutputOp::Params epilogue;

int const *ptr_M;
@ -142,15 +139,11 @@ public:
typename LayoutC::Stride::Index ldd_real;
typename LayoutC::Stride::Index ldd_imag;

int64_t batch_stride_D; // unused

//
// Methods
//

Arguments():
mode(GemmUniversalMode::kArray),
batch_count(1),
ptr_M(nullptr),
ptr_N(nullptr),
ptr_K(nullptr),
@ -161,9 +154,8 @@ public:
ptr_C_real(nullptr),
ptr_C_imag(nullptr),
ptr_D_real(nullptr),
ptr_D_imag(nullptr),
batch_stride_D(0)
{ }
ptr_D_imag(nullptr)
{}

/// constructs an arguments structure
Arguments(
@ -188,11 +180,9 @@ public:
typename LayoutC::Stride::Index ldc_real,
typename LayoutC::Stride::Index ldc_imag,
typename LayoutC::Stride::Index ldd_real,
typename LayoutC::Stride::Index ldd_imag
):
mode(GemmUniversalMode::kArray),
problem_size(problem_size),
batch_count(batch_count),
typename LayoutC::Stride::Index ldd_imag)
:
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
epilogue(epilogue),
ptr_M(ptr_M),
ptr_N(ptr_N),
@ -212,10 +202,8 @@ public:
ldc_real(ldc_real),
ldc_imag(ldc_imag),
ldd_real(ldd_real),
ldd_imag(ldd_imag),
batch_stride_D(0) {

}
ldd_imag(ldd_imag)
{}

/// Returns arguments for the transposed problem
Arguments transposed_problem() const {
@ -232,15 +220,30 @@ public:
}
};


//
// Structure for precomputing values in host memory and passing to kernels
//

/// Parameters structure
struct Params {
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
struct Params : UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC>
{
using ParamsBase = UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC>;

//
// Data members
//

typename Mma::IteratorA::Params params_A_real;
typename Mma::IteratorA::Params params_A_imag;
typename Mma::IteratorB::Params params_B_real;
@ -249,11 +252,9 @@ public:
typename Epilogue::OutputTileIterator::Params params_C_imag;
typename Epilogue::OutputTileIterator::Params params_D_real;
typename Epilogue::OutputTileIterator::Params params_D_imag;


typename EpilogueOutputOp::Params output_op;

int batch_count;

int const *ptr_M;
int const *ptr_N;
int const *ptr_K;
@ -268,35 +269,19 @@ public:
void * const * ptr_D_imag;

//
// Methods
// Host dispatch API
//

CUTLASS_HOST_DEVICE
Params():
batch_count(0),
swizzle_log_tile(0),
ptr_M(nullptr),
ptr_N(nullptr),
ptr_K(nullptr),
ptr_A_real(nullptr),
ptr_A_imag(nullptr),
ptr_B_real(nullptr),
ptr_B_imag(nullptr),
ptr_C_real(nullptr),
ptr_C_imag(nullptr),
ptr_D_real(nullptr),
ptr_D_imag(nullptr) { }
/// Default constructor
Params() = default;

CUTLASS_HOST_DEVICE
/// Constructor
Params(
Arguments const &args,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
int gemm_k_size = 0, // ignored
void *workspace = nullptr // ignored
):
problem_size(args.problem_size),
grid_tiled_shape(grid_tiled_shape),
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
Arguments const &args, /// GEMM application arguments
int device_sms, /// Number of SMs on the device
int sm_occupancy) /// Kernel SM occupancy (in thread blocks)
:
ParamsBase(args, device_sms, sm_occupancy),
ptr_M(args.ptr_M),
ptr_N(args.ptr_N),
ptr_K(args.ptr_K),
@ -309,7 +294,6 @@ public:
params_D_real(args.ldd_real),
params_D_imag(args.ldd_imag),
output_op(args.epilogue),
batch_count(args.batch_count),
ptr_A_real(args.ptr_A_real),
ptr_A_imag(args.ptr_A_imag),
ptr_B_real(args.ptr_B_real),
@ -317,14 +301,13 @@ public:
ptr_C_real(args.ptr_C_real),
ptr_C_imag(args.ptr_C_imag),
ptr_D_real(args.ptr_D_real),
ptr_D_imag(args.ptr_D_imag) {

}

void update(
Arguments const &args,
void *workspace = nullptr) {
ptr_D_imag(args.ptr_D_imag)
{}

/// Lightweight update given a subset of arguments. Problem geometry is assumed
/// to remain the same.
void update(Arguments const &args)
{
ptr_M = args.ptr_M;
ptr_N = args.ptr_N;
ptr_K = args.ptr_K;
@ -345,6 +328,7 @@ public:
}
};


/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
@ -354,12 +338,9 @@ public:
public:

//
// Methods
// Host dispatch API
//

CUTLASS_DEVICE
GemmPlanarComplexArray() { }

/// Determines whether kernel satisfies alignment
static Status can_implement(Arguments const &args) {

@ -396,12 +377,24 @@ public:
return Status::kSuccess;
}

static size_t get_extra_workspace_size(Arguments const &args,
cutlass::gemm::GemmCoord const &grid_tiled_shape) {

return 0;
public:

//
// Device-only API
//

// Factory invocation
CUTLASS_DEVICE
static void invoke(
Params const &params,
SharedStorage &shared_storage)
{
GemmPlanarComplexArray op;
op(params, shared_storage);
}


/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {

@ -37,12 +37,12 @@

#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"

#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/params_universal_base.h"

#include "cutlass/trace.h"

@ -55,7 +55,7 @@ namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////

template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
@ -101,16 +101,12 @@ public:
//

/// Argument structure
struct Arguments {

struct Arguments : UniversalArgumentsBase
{
//
// Data members
//

GemmUniversalMode mode;
GemmCoord problem_size;
int batch_count;

typename EpilogueOutputOp::Params epilogue;

void const * ptr_A;
@ -121,7 +117,6 @@ public:
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_C;
int64_t batch_stride_D;

typename LayoutA::Stride stride_a;
typename LayoutB::Stride stride_b;
@ -140,14 +135,13 @@ public:
//
// Methods
//

Arguments():
mode(GemmUniversalMode::kGemm),
batch_count(1),

Arguments():
ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr),
ptr_gather_A_indices(nullptr),
ptr_gather_B_indices(nullptr),
ptr_scatter_D_indices(nullptr) {}
ptr_scatter_D_indices(nullptr)
{}

/// constructs an arguments structure
Arguments(
@ -169,23 +163,22 @@ public:
typename LayoutC::Stride stride_d,
int const *ptr_gather_A_indices = nullptr,
int const *ptr_gather_B_indices = nullptr,
int const *ptr_scatter_D_indices = nullptr
):
mode(mode),
problem_size(problem_size),
batch_count(batch_count),
int const *ptr_scatter_D_indices = nullptr)
:
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
epilogue(epilogue),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),
stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d),
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
ptr_scatter_D_indices(ptr_scatter_D_indices) {
ptr_scatter_D_indices(ptr_scatter_D_indices)
{
lda = 0;
ldb = 0;
ldc = 0;
ldd = 0;
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
}
}

/// constructs an arguments structure
Arguments(
@ -209,26 +202,26 @@ public:
int const *ptr_gather_B_indices = nullptr,
int const *ptr_scatter_D_indices = nullptr
):
mode(mode),
problem_size(problem_size),
batch_count(batch_count),
epilogue(epilogue),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D),
epilogue(epilogue),
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C),
lda(lda), ldb(ldb), ldc(ldc), ldd(ldd),
ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices),
ptr_scatter_D_indices(ptr_scatter_D_indices) {
ptr_scatter_D_indices(ptr_scatter_D_indices)
{
stride_a = make_Coord(lda);
stride_b = make_Coord(ldb);
stride_c = make_Coord(ldc);
stride_d = make_Coord(ldd);
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
}
}

/// Returns arguments for the transposed problem
Arguments transposed_problem() const {
Arguments transposed_problem() const
{
Arguments args(*this);


std::swap(args.problem_size.m(), args.problem_size.n());
std::swap(args.ptr_A, args.ptr_B);
std::swap(args.lda, args.ldb);
@ -240,27 +233,36 @@ public:
}
};


//
// Structure for precomputing values in host memory and passing to kernels
//

/// Parameters structure
struct Params {
struct Params : UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC>
{
using ParamsBase = UniversalParamsBase<
ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC>;

//
// Data members
//

cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;

typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
typename Epilogue::OutputTileIterator::Params params_C;
typename Epilogue::OutputTileIterator::Params params_D;

typename EpilogueOutputOp::Params output_op;

GemmUniversalMode mode;
int batch_count;
int gemm_k_size;
typename EpilogueOutputOp::Params output_op;

void * ptr_A;
void * ptr_B;
@ -270,59 +272,30 @@ public:
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_C;
int64_t batch_stride_D;

int * ptr_gather_A_indices;
int * ptr_gather_B_indices;
int * ptr_scatter_D_indices;

int *semaphore;

//
// Methods
// Host dispatch API
//

CUTLASS_HOST_DEVICE
Params():
swizzle_log_tile(0),
params_A(0),
params_B(0),
params_C(0),
params_D(0),
batch_count(0),
gemm_k_size(0),
mode(cutlass::gemm::GemmUniversalMode::kGemm),
ptr_A(nullptr),
ptr_B(nullptr),
ptr_C(nullptr),
ptr_D(nullptr),
batch_stride_A(0),
batch_stride_B(0),
batch_stride_C(0),
batch_stride_D(0),
ptr_gather_A_indices(nullptr),
ptr_gather_B_indices(nullptr),
ptr_scatter_D_indices(nullptr),
semaphore(nullptr) { }
/// Default constructor
Params() = default;

CUTLASS_HOST_DEVICE
/// Constructor
Params(
Arguments const &args,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
int gemm_k_size,
void *workspace = nullptr
):
problem_size(args.problem_size),
grid_tiled_shape(grid_tiled_shape),
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
Arguments const &args, /// GEMM application arguments
int device_sms, /// Number of SMs on the device
int sm_occupancy) /// Kernel SM occupancy (in thread blocks)
:
ParamsBase(args, device_sms, sm_occupancy),
params_A(args.lda ? make_Coord_with_padding<LayoutA::kStrideRank>(args.lda) : args.stride_a),
params_B(args.ldb ? make_Coord_with_padding<LayoutB::kStrideRank>(args.ldb) : args.stride_b),
params_C(args.ldc ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldc) : args.stride_c),
params_D(args.ldd ? make_Coord_with_padding<LayoutC::kStrideRank>(args.ldd) : args.stride_d),
output_op(args.epilogue),
mode(args.mode),
batch_count(args.batch_count),
gemm_k_size(gemm_k_size),
ptr_A(const_cast<void *>(args.ptr_A)),
ptr_B(const_cast<void *>(args.ptr_B)),
ptr_C(const_cast<void *>(args.ptr_C)),
@ -330,19 +303,18 @@ public:
batch_stride_A(args.batch_stride_A),
batch_stride_B(args.batch_stride_B),
batch_stride_C(args.batch_stride_C),
batch_stride_D(args.batch_stride_D),
ptr_gather_A_indices(const_cast<int *>(args.ptr_gather_A_indices)),
ptr_gather_B_indices(const_cast<int *>(args.ptr_gather_B_indices)),
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices)),
semaphore(static_cast<int *>(workspace)) {

}

CUTLASS_HOST_DEVICE
void update(
Arguments const &args,
void *workspace = nullptr) {
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices))
{}

/// Lightweight update given a subset of arguments. Problem geometry is assumed
/// to remain the same.
void update(Arguments const &args)
{
CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");

// Update input/output pointers
ptr_A = const_cast<void *>(args.ptr_A);
ptr_B = const_cast<void *>(args.ptr_B);
ptr_C = const_cast<void *>(args.ptr_C);
@ -352,37 +324,28 @@ public:
ptr_gather_B_indices = const_cast<int *>(args.ptr_gather_B_indices);
ptr_scatter_D_indices = const_cast<int *>(args.ptr_scatter_D_indices);

batch_stride_A = args.batch_stride_A;
batch_stride_B = args.batch_stride_B;
batch_stride_C = args.batch_stride_C;
batch_stride_D = args.batch_stride_D;

output_op = args.epilogue;

semaphore = static_cast<int *>(workspace);
CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");
}
};


/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};


public:

//
// Methods
// Host dispatch API
//

CUTLASS_DEVICE
GemmUniversal() { }

/// Determines whether kernel satisfies alignment
static Status can_implement(
cutlass::gemm::GemmCoord const & problem_size) {

cutlass::gemm::GemmCoord const & problem_size)
{
CUTLASS_TRACE_HOST("GemmUniversal::can_implement()");

static int const kAlignmentA = (platform::is_same<LayoutA,
@ -462,15 +425,30 @@ public:
return can_implement(args.problem_size);
}

static size_t get_extra_workspace_size(Arguments const &args,
cutlass::gemm::GemmCoord const &grid_tiled_shape) {

return 0;
public:

//
// Device-only API
//

// Factory invocation
CUTLASS_DEVICE
static void invoke(
Params const &params,
SharedStorage &shared_storage)
{
GemmUniversal op;
op(params, shared_storage);
}



/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
void operator()(
Params const &params,
SharedStorage &shared_storage)
{

// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
@ -677,7 +655,7 @@ public:
// Release the semaphore
//

if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {

int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
1126 include/cutlass/gemm/kernel/gemm_universal_streamk.h Normal file
File diff suppressed because it is too large
Some files were not shown because too many files have changed in this diff