/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/numeric_conversion.h"

#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/mma_base.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Transformation applied to A operand
    typename TransformA_ = NumericArrayConverter<
        typename SmemIteratorA_::Element,
        typename IteratorA_::Element,
        IteratorA_::Fragment::kElements>,
    ///
    /// Transformation applied to B operand
    typename TransformB_ = NumericArrayConverter<
        typename SmemIteratorB_::Element,
        typename IteratorB_::Element,
        IteratorB_::Fragment::kElements>,
    /// Used for partial specialization
    typename Enable = bool
>
class ImplicitGemmPipelined : public gemm::threadblock::MmaBase<Shape_, Policy_, 2> {
public:

  ///< Base class
  using Base = gemm::threadblock::MmaBase<Shape_, Policy_, 2>;

  using Shape = Shape_;          ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using IteratorA = IteratorA_;  ///< Iterates over tiles of A operand in global memory
  using IteratorB = IteratorB_;  ///< Iterates over tiles of B operand in global memory
  using ElementC = ElementC_;    ///< Data type of accumulator matrix
  using LayoutC = LayoutC_;      ///< Layout of accumulator matrix
  using Policy = Policy_;        ///< Policy describing tuning details

  using SmemIteratorA = SmemIteratorA_;
  using SmemIteratorB = SmemIteratorB_;

  using TransformA = TransformA_;
  using TransformB = TransformB_;

  //
  // Dependent types
  //

  /// Fragment of operand A loaded from global memory
  using FragmentA = typename IteratorA::Fragment;

  /// Fragment of operand B loaded from global memory
  using FragmentB = typename IteratorB::Fragment;

  /// Fragment of accumulator tile
  using FragmentC = typename Policy::Operator::FragmentC;

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  /// Obtain the arch tag from the warp-level operator
  using ArchTag = typename Policy::Operator::ArchTag;

  /// Complex transform on A operand
  static ComplexTransform const kTransformA = Operator::kTransformA;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB = Operator::kTransformB;

  // Statically assert that kStages for MmaPipelined is two (double-buffered pipeline)
  static_assert((Base::kStages == 2), "MmaPipelined requires kStages set to value 2");

private:

  using WarpFragmentA = typename Operator::FragmentA;
  using WarpFragmentB = typename Operator::FragmentB;

protected:

  /// Iterator to write threadblock-scoped tile of A operand to shared memory
  SmemIteratorA smem_iterator_A_;

  /// Iterator to write threadblock-scoped tile of B operand to shared memory
  SmemIteratorB smem_iterator_B_;

public:

  /// Construct from tensor references
  CUTLASS_DEVICE
  ImplicitGemmPipelined(
    typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
    int thread_idx,                               ///< ID within the threadblock
    int warp_idx,                                 ///< ID of warp
    int lane_idx                                  ///< ID of each thread within a warp
  ):
    Base(shared_storage, thread_idx, warp_idx, lane_idx),
    smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
    smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {

    // Compute warp location within threadblock tile by mapping the warp_id to
    // three coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension

    int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
    int warp_idx_k  = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

    int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
    int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

    // Add per-warp offsets in units of warp-level tiles
    this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
    this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
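
    // Worked example (illustrative values, not a requirement of this class): with
    // Base::WarpCount = {kM = 2, kN = 2, kK = 1} and warp_idx = 3,
    //   warp_idx_mn = 3 % 4 = 3,  warp_idx_k = 3 / 4 = 0,
    //   warp_idx_m  = 3 % 2 = 1,  warp_idx_n  = 3 / 2 = 1,
    // so this warp computes the lower-right warp-level tile of the threadblock tile.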
  }

  /// Perform a threadblock-scoped matrix multiply-accumulate
  CUTLASS_DEVICE
  void operator()(
    int gemm_k_iterations,                    ///< number of iterations of the mainloop
    FragmentC &accum,                         ///< destination accumulator tile
    IteratorA iterator_A,                     ///< iterator over A operand in global memory
    IteratorB iterator_B,                     ///< iterator over B operand in global memory
    FragmentC const &src_accum,               ///< source accumulator tile
    TransformA transform_A = TransformA(),    ///< transformation applied to A fragment
    TransformB transform_B = TransformB()) {  ///< transformation applied to B fragment

    //
    // Prologue
    //

    // Perform accumulation in the 'd' output operand
    accum = src_accum;

    FragmentA tb_frag_A;
    FragmentB tb_frag_B;

    tb_frag_A.clear();
    tb_frag_B.clear();

    // The last kblock is loaded in the prologue
    iterator_A.load(tb_frag_A);
    iterator_B.load(tb_frag_B);

    ++iterator_A;
    ++iterator_B;
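
    // transform_A / transform_B convert each global-memory fragment to the shared-memory
    // element type as it is stored; with the default NumericArrayConverter and matching
    // element types this is effectively a pass-through copy.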
    this->smem_iterator_A_.store(transform_A(tb_frag_A));
    this->smem_iterator_B_.store(transform_B(tb_frag_B));

    ++this->smem_iterator_A_;
    ++this->smem_iterator_B_;
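
    // Barrier so the prologue's shared-memory stores are visible to every warp before the
    // first warp-level fragments are loaded back out of shared memory.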
    __syncthreads();

    // Pair of fragments used to overlap shared memory loads and math instructions
    WarpFragmentA warp_frag_A[2];
    WarpFragmentB warp_frag_B[2];

    this->warp_tile_iterator_A_.set_kgroup_index(0);
    this->warp_tile_iterator_B_.set_kgroup_index(0);

    this->warp_tile_iterator_A_.load(warp_frag_A[0]);
    this->warp_tile_iterator_B_.load(warp_frag_B[0]);

    ++this->warp_tile_iterator_A_;
    ++this->warp_tile_iterator_B_;

    Operator warp_mma;

    int smem_write_stage_idx = 1;
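
    // The shared-memory buffer holds Base::kStages == 2 threadblock tiles.
    // smem_write_stage_idx tracks which stage the write iterators currently point at; the
    // warp-level read iterators trail one stage behind, which is what lets global loads for
    // the next tile overlap with math on the current tile.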

    // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
    // shared memory loads (which have the tightest latency requirement).

    //
    // Mainloop
    //

    // Note: The main loop does not support Base::kWarpGemmIterations == 2.
    CUTLASS_GEMM_LOOP
    for (; gemm_k_iterations > 0; --gemm_k_iterations) {
      //
      // Loop over GEMM K dimension
      //

      CUTLASS_PRAGMA_UNROLL
      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {

        // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
        // as the case may be.

        if (warp_mma_k == Base::kWarpGemmIterations - 1) {

          // Write fragments to shared memory
          this->smem_iterator_A_.store(transform_A(tb_frag_A));
          this->smem_iterator_B_.store(transform_B(tb_frag_B));

          __syncthreads();

          ++this->smem_iterator_A_;
          ++this->smem_iterator_B_;

          // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
          if (smem_write_stage_idx == 1) {
            this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
            this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
          }
          else {
            this->warp_tile_iterator_A_.add_tile_offset(
                {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
            this->warp_tile_iterator_B_.add_tile_offset(
                {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations,
                 0});
          }

          smem_write_stage_idx ^= 1;
        }
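
        // Prefetch the warp-level fragments for iteration warp_mma_k + 1 while the math
        // below consumes the fragments loaded on the previous iteration (ping-pong between
        // the two entries of warp_frag_A / warp_frag_B).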
        this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
        this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);

        this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
        this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;
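
        // On the first k-group of each mainloop iteration, issue the global loads for the
        // next threadblock tile so they overlap with the remaining warp-level math.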
        if (warp_mma_k == 0) {

          iterator_A.load(tb_frag_A);
          iterator_B.load(tb_frag_B);

          ++iterator_A;
          ++iterator_B;
        }

        warp_mma(accum, warp_frag_A[warp_mma_k % 2],
                 warp_frag_B[warp_mma_k % 2], accum);
      }
    }

  }
};
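
/// Usage sketch (illustrative only): the shape, iterator, and policy types below are
/// assumptions standing in for what a kernel-level default such as
/// cutlass::conv::kernel::DefaultConv2dFprop would normally derive.
///
///   using Mma = cutlass::conv::threadblock::ImplicitGemmPipelined<
///       ThreadblockShape,              // e.g. gemm::GemmShape<128, 128, 8>
///       IteratorA, SmemIteratorA,      // activation tile iterators (global / shared)
///       IteratorB, SmemIteratorB,      // filter tile iterators (global / shared)
///       ElementC, LayoutC, MmaPolicy>;
///
///   // Inside a hypothetical kernel body:
///   __shared__ typename Mma::SharedStorage shared_storage;
///   Mma mma(shared_storage, thread_idx, warp_idx, lane_idx);
///
///   typename Mma::FragmentC accum;
///   accum.clear();
///   mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);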

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////